Skip to content

FastAPI Dependency Injection

Overview

Dependency Injection is one of FastAPI's core features, providing an elegant way to manage application dependencies. Through dependency injection, we can achieve code decoupling, improve testability, and enhance maintainability. This chapter takes an in-depth look at FastAPI's dependency injection system.

🔧 Dependency Injection Basics

Simple Dependencies

python
import asyncio
import uuid
from datetime import datetime
from typing import Optional

from fastapi import Depends, FastAPI, Header, HTTPException, Query, Request

# Application instance; the route and middleware decorators below register on it.
app = FastAPI()

def get_current_time():
    """Return the current local date and time as a naive ``datetime``."""
    import datetime as _datetime

    return _datetime.datetime.now()

def get_user_agent(user_agent: Optional[str] = Header(None)):
    """Return the ``User-Agent`` header value, or ``"Unknown"`` when it is missing or empty."""
    if user_agent:
        return user_agent
    return "Unknown"

@app.get("/info/")
async def get_info(
    current_time: datetime = Depends(get_current_time),
    user_agent: str = Depends(get_user_agent)
):
    """Return the injected timestamp (ISO 8601 text) and User-Agent plus a fixed message."""
    payload = {"current_time": current_time.isoformat()}
    payload["user_agent"] = user_agent
    payload["message"] = "Dependency injection example"
    return payload

# Dependencies can also be async
async def get_database_info():
    """Async dependency standing in for a database metadata lookup.

    Returns:
        A dict with fixed ``database``, ``version`` and ``status`` entries.
    """
    simulated_latency = 0.1  # seconds; imitates a database round trip
    await asyncio.sleep(simulated_latency)
    return dict(database="PostgreSQL", version="13.7", status="connected")

@app.get("/db-info/")
async def get_db_info(db_info: dict = Depends(get_database_info)):
    """Return the dict produced by the ``get_database_info`` dependency unchanged."""
    return db_info

Dependencies with Parameters

python
def get_query_extractor(query_param: str = "q", default_value: str = ""):
    """Build a dependency that reads a single query-string parameter.

    Args:
        query_param: Name of the query parameter to look up.
        default_value: Value returned when the parameter is absent.

    Returns:
        A callable that takes the current ``Request`` and returns the
        parameter's value, or ``default_value`` when it is not present.
    """
    def query_extractor(request: Request):
        params = request.query_params
        return params.get(query_param, default_value)

    return query_extractor

def get_pagination_params(page: int = Query(1, ge=1), size: int = Query(10, ge=1, le=100)):
    """Collect pagination inputs and derive the row offset.

    Returns:
        A dict with ``page``, ``size`` and ``offset``, where ``offset`` is
        the number of items preceding the requested page.
    """
    offset = size * (page - 1)
    return dict(page=page, size=size, offset=offset)

def get_sorting_params(
    sort_by: str = Query("created_at", description="Sort field"),
    sort_order: str = Query("desc", regex="^(asc|desc)$", description="Sort direction")
):
    """Collect the sort field and direction into a dict.

    ``sort_order`` is declared with a pattern that admits only ``asc`` or ``desc``.
    """
    sorting = dict(sort_by=sort_by, sort_order=sort_order)
    return sorting

@app.get("/posts/")
async def list_posts(
    pagination: dict = Depends(get_pagination_params),
    sorting: dict = Depends(get_sorting_params),
    search_query: str = Depends(get_query_extractor("search", ""))
):
    """Echo the resolved pagination, sorting and search inputs with an empty post list."""
    posts = []  # placeholder: this example has no real post storage
    return dict(
        pagination=pagination,
        sorting=sorting,
        search_query=search_query,
        posts=posts,
    )

🏗️ Classes as Dependencies

Simple Class Dependencies

python
class DatabaseConnection:
    """Simulated database connection used by the examples.

    No real I/O happens: ``connect`` and ``execute_query`` only sleep
    briefly, and the connection state is tracked by the ``connected`` flag.
    """

    def __init__(self):
        self.host, self.port, self.database = "localhost", 5432, "myapp"
        self.connected = False

    async def connect(self):
        """Mark the connection as open after a short simulated delay and return ``self``."""
        await asyncio.sleep(0.1)
        self.connected = True
        return self

    async def disconnect(self):
        """Mark the connection as closed."""
        self.connected = False

    async def execute_query(self, query: str):
        """Pretend to run ``query`` and return a string echoing it.

        Raises:
            Exception: If called while the connection is not open.
        """
        if self.connected:
            await asyncio.sleep(0.05)
            return f"Execute query: {query}"
        raise Exception("Database not connected")

class UserService:
    """User operations built on an injected ``DatabaseConnection``.

    Each operation opens the connection, runs one query and always closes
    the connection again, even when the query raises.
    """

    def __init__(self, db: DatabaseConnection = Depends()):
        self.db = db

    async def get_user_by_id(self, user_id: int):
        """Get user by ID.

        Returns:
            A dict with ``user_id``, a generated ``name`` and the raw
            ``query_result`` returned by the connection.
        """
        await self.db.connect()
        try:
            # NOTE(review): the query text is built by interpolation because
            # execute_query accepts only a string; use bound parameters once
            # a real driver is used.
            result = await self.db.execute_query(f"SELECT * FROM users WHERE id = {user_id}")
        finally:
            # Release the connection even if the query fails.
            await self.db.disconnect()
        return {"user_id": user_id, "name": f"User {user_id}", "query_result": result}

    async def create_user(self, user_data: dict):
        """Create user.

        Returns:
            A dict with a success ``message`` and the ``user_data`` passed in.
        """
        await self.db.connect()
        try:
            await self.db.execute_query("INSERT INTO users VALUES (...)")
        finally:
            # Release the connection even if the query fails.
            await self.db.disconnect()
        return {"message": "User created successfully", "user_data": user_data}

@app.get("/users/{user_id}")
async def get_user(user_id: int, user_service: UserService = Depends()):
    """Look up one user by the ``user_id`` path parameter via the injected ``UserService``."""
    return await user_service.get_user_by_id(user_id)

@app.post("/users/")
async def create_user(user_data: dict, user_service: UserService = Depends()):
    """Pass the ``user_data`` dict to the injected ``UserService`` and return its result."""
    return await user_service.create_user(user_data)

Classes with Configuration as Dependencies

python
from pydantic import BaseSettings
import httpx

class APISettings(BaseSettings):
    """Configuration for ``ExternalAPIClient``; the values below are the defaults."""

    # Base URL given to the HTTP client.
    external_api_url: str = "https://api.example.com"
    # Sent as the Bearer token in the Authorization header.
    # NOTE(review): "default-key" looks like a placeholder — confirm a real key is supplied.
    api_key: str = "default-key"
    # Passed to the HTTP client as its timeout (presumably seconds — confirm).
    timeout: int = 30
    # Number of attempts ``ExternalAPIClient.get_data`` makes before giving up.
    max_retries: int = 3

    class Config:
        # File the settings loader is pointed at for overrides.
        env_file = ".env"

class ExternalAPIClient:
    """Async client for the external API described by ``APISettings``.

    Use it as an async context manager: the underlying ``httpx.AsyncClient``
    exists only between ``__aenter__`` and ``__aexit__``.
    """

    def __init__(self, settings: APISettings = Depends()):
        self.settings = settings
        self.client = None

    async def __aenter__(self):
        """Create the HTTP client from the settings and return ``self``."""
        self.client = httpx.AsyncClient(
            base_url=self.settings.external_api_url,
            timeout=self.settings.timeout,
            headers={"Authorization": f"Bearer {self.settings.api_key}"}
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP client, if one was created."""
        if self.client:
            try:
                await self.client.aclose()
            finally:
                # Drop the reference so a later get_data() fails with
                # "Client not initialized" instead of using a closed client.
                self.client = None

    async def get_data(self, endpoint: str):
        """GET ``endpoint`` and return the decoded JSON body.

        Only ``httpx.RequestError`` is retried, with a 1s, 2s, 4s, ...
        backoff between attempts; any other exception propagates at once.

        Raises:
            Exception: If called outside the ``async with`` block.
            httpx.RequestError: If the last attempt also fails.
        """
        if not self.client:
            raise Exception("Client not initialized")

        # Always make at least one attempt; a non-positive max_retries would
        # otherwise skip the loop and silently return None.
        attempts = max(1, self.settings.max_retries)
        for attempt in range(attempts):
            try:
                response = await self.client.get(endpoint)
                response.raise_for_status()
                return response.json()
            except httpx.RequestError:
                if attempt == attempts - 1:
                    raise
                await asyncio.sleep(2 ** attempt)  # Exponential backoff

async def get_api_client():
    """Yield an entered ``ExternalAPIClient`` built from fresh ``APISettings``.

    The client's context manager is exited when the generator resumes after
    the ``yield``.
    """
    async with ExternalAPIClient(APISettings()) as api_client:
        yield api_client

@app.get("/external-data/{endpoint}")
async def get_external_data(
    endpoint: str,
    api_client: ExternalAPIClient = Depends(get_api_client)
):
    """Fetch ``endpoint`` from the external API through the injected client.

    Returns:
        A dict with the fetched ``data`` and a fixed ``source`` marker.

    Raises:
        HTTPException: 503 when the external call fails for any reason.
    """
    try:
        data = await api_client.get_data(endpoint)
    except Exception as e:
        # Chain the cause so the original traceback stays visible in logs.
        raise HTTPException(status_code=503, detail=f"External API call failed: {str(e)}") from e
    return {"data": data, "source": "external_api"}

🔄 Sub-dependencies

Multiple Layer Dependencies

python
class ConfigService:
    """Holds the application's configuration values in an in-memory dict."""

    def __init__(self):
        self.config = dict(
            app_name="MyApp",
            version="1.0.0",
            debug=True,
            database_url="postgresql://localhost/myapp",
        )

    def get(self, key: str, default=None):
        """Return the value stored under ``key``, or ``default`` when it is absent."""
        settings = self.config
        return settings.get(key, default)

class LoggingService:
    """Prints log lines prefixed with a level and the configured application name."""

    def __init__(self, config: ConfigService = Depends()):
        self.config = config
        self.app_name = config.get("app_name")
        self.debug = config.get("debug")

    def log_info(self, message: str):
        """Print ``message`` at DEBUG level when debug is enabled, otherwise at INFO."""
        if self.debug:
            level = "DEBUG"
        else:
            level = "INFO"
        print(f"[{level}] {self.app_name}: {message}")

    def log_error(self, message: str):
        """Print ``message`` at ERROR level."""
        line = f"[ERROR] {self.app_name}: {message}"
        print(line)

class CacheService:
    """Dict-backed cache that logs every hit, miss and write.

    NOTE(review): the backing dict is created in ``__init__``, so each
    instance starts empty. If the framework builds a new instance per
    request, nothing is cached across requests — confirm the intended
    lifetime.
    """

    def __init__(self, config: ConfigService = Depends(), logger: LoggingService = Depends()):
        self.config = config
        self.logger = logger
        self.cache = {}
        self.logger.log_info("Cache service initialized")

    def get(self, key: str):
        """Return the cached value for ``key``, or ``None`` when it is not cached."""
        # Decide hit/miss by key presence, not truthiness, so cached falsy
        # values (0, "", [], {}) are not reported as misses.
        if key in self.cache:
            self.logger.log_info(f"Cache hit: {key}")
        else:
            self.logger.log_info(f"Cache miss: {key}")
        return self.cache.get(key)

    def set(self, key: str, value, ttl: int = 300):
        """Store ``value`` under ``key``.

        ``ttl`` is accepted for interface compatibility but is not applied:
        entries never expire.
        """
        self.cache[key] = value
        self.logger.log_info(f"Cache set: {key}")

class BusinessService:
    """Business-data service combining the config, logging and cache services."""

    def __init__(
        self,
        config: ConfigService = Depends(),
        logger: LoggingService = Depends(),
        cache: CacheService = Depends()
    ):
        self.config = config
        self.logger = logger
        self.cache = cache

    async def get_business_data(self, data_id: str):
        """Return the record for ``data_id``, using the cache when it has one.

        On a miss the record is built after a short simulated delay, stored
        in the cache and returned.
        """
        cache_key = f"business_data:{data_id}"

        hit = self.cache.get(cache_key)
        if hit:
            return hit

        self.logger.log_info(f"Processing business data: {data_id}")
        await asyncio.sleep(0.1)  # stands in for real processing time

        record = dict(
            id=data_id,
            name=f"Business Data {data_id}",
            processed_at=datetime.now().isoformat(),
        )
        self.cache.set(cache_key, record)
        return record

@app.get("/business/{data_id}")
async def get_business_data(data_id: str, service: BusinessService = Depends()):
    """Return the business record for ``data_id`` from the injected ``BusinessService``."""
    return await service.get_business_data(data_id)

🔒 Security Dependencies

Authentication Dependencies

python
import jwt
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

# Bearer-token scheme instance; get_current_user depends on it to obtain the credentials.
security = HTTPBearer()

class User:
    """Authenticated user with an id, name, email address and list of roles."""

    def __init__(self, user_id: int, username: str, email: str, roles: list):
        self.user_id, self.username = user_id, username
        self.email, self.roles = email, roles

    def has_role(self, role: str) -> bool:
        """Return whether ``role`` is among this user's roles."""
        held_roles = self.roles
        return role in held_roles

    def dict(self):
        """Return the user's fields as a plain dict."""
        field_names = ("user_id", "username", "email", "roles")
        return {name: getattr(self, name) for name in field_names}

# Key get_current_user passes to jwt.decode to verify HS256 tokens.
# NOTE(review): hard-coded placeholder value — load the real key from the
# environment or a secret store before deploying.
SECRET_KEY = "your-secret-key"

async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Decode the bearer token and return the ``User`` it describes.

    Raises:
        HTTPException: 401 with "Token expired" for an expired token, or
            401 with "Invalid token" when the token cannot be decoded or
            carries no ``user_id`` claim.
    """
    token = credentials.credentials

    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
    except jwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token expired")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")

    user_id = claims.get("user_id")
    if user_id is None:
        raise HTTPException(status_code=401, detail="Invalid token")

    return User(
        user_id,
        claims.get("username"),
        claims.get("email"),
        claims.get("roles", []),
    )

async def get_admin_user(current_user: User = Depends(get_current_user)):
    """Return the current user when they hold the ``admin`` role.

    Raises:
        HTTPException: 403 when the user is not an admin.
    """
    if current_user.has_role("admin"):
        return current_user
    raise HTTPException(status_code=403, detail="Admin privileges required")

def require_roles(*required_roles):
    """Build a dependency that admits users holding at least one of ``required_roles``.

    Returns:
        An async callable that returns the current user, or raises a 403
        ``HTTPException`` naming the accepted roles.
    """
    async def role_checker(current_user: User = Depends(get_current_user)):
        for role in required_roles:
            if current_user.has_role(role):
                return current_user
        raise HTTPException(
            status_code=403,
            detail=f"Requires one of the following roles: {', '.join(required_roles)}"
        )

    return role_checker

# Use authentication dependencies
@app.get("/profile/")
async def get_profile(current_user: User = Depends(get_current_user)):
    """Return the authenticated user's fields under the ``profile`` key."""
    profile = current_user.dict()
    return {"profile": profile}

@app.get("/admin/users/")
async def list_all_users(admin_user: User = Depends(get_admin_user)):
    """Return an empty user list together with the requesting admin's username."""
    users = []  # placeholder: this example has no user storage
    return {"users": users, "admin": admin_user.username}

@app.post("/admin/posts/")
async def create_post(
    post_data: dict,
    editor: User = Depends(require_roles("admin", "editor"))
):
    """Acknowledge post creation for a user holding the ``admin`` or ``editor`` role."""
    created_by = editor.username
    return {"message": "Post created successfully", "created_by": created_by}

Permission Dependencies

python
from enum import Enum

class Permission(str, Enum):
    """Named permissions checked by ``PermissionChecker``; members are also ``str`` values."""

    # Post-related permissions.
    READ_POSTS = "read_posts"
    WRITE_POSTS = "write_posts"
    DELETE_POSTS = "delete_posts"
    # User-management permission.
    MANAGE_USERS = "manage_users"
    # Not present in PermissionChecker's role-to-permission table.
    SYSTEM_ADMIN = "system_admin"

class PermissionChecker:
    """Callable dependency that admits users holding one required permission."""

    def __init__(self, required_permission: Permission):
        self.required_permission = required_permission

    async def __call__(self, current_user: User = Depends(get_current_user)):
        """Return the current user when they are allowed through.

        Users with the ``system_admin`` role always pass; everyone else must
        hold ``required_permission`` through one of their roles.

        Raises:
            HTTPException: 403 naming the missing permission.
        """
        if current_user.has_role("system_admin"):
            return current_user

        granted = self._get_user_permissions(current_user)
        if self.required_permission in granted:
            return current_user

        raise HTTPException(
            status_code=403,
            detail=f"Requires permission: {self.required_permission.value}"
        )

    def _get_user_permissions(self, user: User) -> list[Permission]:
        """Return the distinct permissions granted by the user's roles.

        Roles missing from the table grant nothing.
        """
        role_permissions = {
            "admin": [Permission.READ_POSTS, Permission.WRITE_POSTS, Permission.DELETE_POSTS, Permission.MANAGE_USERS],
            "editor": [Permission.READ_POSTS, Permission.WRITE_POSTS],
            "user": [Permission.READ_POSTS]
        }

        collected = set()
        for role in user.roles:
            collected.update(role_permissions.get(role, []))

        return list(collected)

# Use permission dependencies
@app.get("/posts/")
async def list_posts(user: User = Depends(PermissionChecker(Permission.READ_POSTS))):
    """Return an empty post list and the caller's username; needs ``READ_POSTS``.

    NOTE(review): an earlier example also registers GET ``/posts/`` under the
    name ``list_posts``. If both live in one module, confirm which handler is
    meant to serve the path.
    """
    posts = []  # placeholder: this example has no post storage
    return {"posts": posts, "user": user.username}

@app.post("/posts/")
async def create_post(
    post_data: dict,
    user: User = Depends(PermissionChecker(Permission.WRITE_POSTS))
):
    """Acknowledge post creation and name the author; needs ``WRITE_POSTS``."""
    author = user.username
    return {"message": "Post created successfully", "author": author}

@app.delete("/posts/{post_id}")
async def delete_post(
    post_id: int,
    user: User = Depends(PermissionChecker(Permission.DELETE_POSTS))
):
    """Acknowledge deletion of ``post_id`` and name the caller; needs ``DELETE_POSTS``."""
    message = f"Post {post_id} deleted successfully"
    return {"message": message, "deleted_by": user.username}

🎛️ Dependency Providers and Scopes

Singleton Dependencies

python
class SingletonService:
    """Service whose constructor always hands back the same instance.

    ``__new__`` creates the instance on first use and reuses it afterwards;
    ``__init__`` sets up the state only the first time it runs.
    """

    _instance = None
    _initialized = False

    def __new__(cls):
        existing = cls._instance
        if existing is None:
            existing = cls._instance = super().__new__(cls)
        return existing

    def __init__(self):
        if self._initialized:
            return  # state already set up by an earlier construction
        self.data = {}
        self.counter = 0
        self._initialized = True

    def increment(self):
        """Add one to the counter and return the new value."""
        self.counter = self.counter + 1
        return self.counter

    def set_data(self, key: str, value):
        """Store ``value`` under ``key``."""
        self.data.update({key: value})

    def get_data(self, key: str):
        """Return the value stored under ``key``, or ``None`` when absent."""
        stored = self.data
        return stored.get(key)

def get_singleton_service():
    """Singleton service provider: returns the result of calling ``SingletonService()``."""
    return SingletonService()

@app.get("/singleton/increment/")
async def increment_counter(service: SingletonService = Depends(get_singleton_service)):
    """Increment the injected service's counter and return the new count."""
    return {"count": service.increment()}

@app.post("/singleton/data/{key}")
async def set_singleton_data(
    key: str,
    value: str,
    service: SingletonService = Depends(get_singleton_service)
):
    """Store ``value`` under ``key`` in the injected service and confirm it."""
    service.set_data(key, value)
    confirmation = f"Set {key} = {value}"
    return {"message": confirmation}

@app.get("/singleton/data/{key}")
async def get_singleton_data(
    key: str,
    service: SingletonService = Depends(get_singleton_service)
):
    """Return ``key`` with the value the injected service holds for it (``None`` if absent)."""
    return {"key": key, "value": service.get_data(key)}

Scoped Dependencies

python
from contextvars import ContextVar

# Context variable holding the current request id; empty string when unset
request_id_var: ContextVar[str] = ContextVar('request_id', default='')

class RequestScopedService:
    """Collects key/value data for one request.

    The request id is read from ``request_id_var`` when the instance is
    created, along with the creation time.
    """

    def __init__(self):
        self.request_id = request_id_var.get()
        self.created_at = datetime.now()
        self.request_data = {}

    def add_data(self, key: str, value):
        """Record ``value`` under ``key``."""
        self.request_data.update({key: value})

    def get_summary(self):
        """Return the request id, creation time (ISO 8601), entry count and the data."""
        collected = self.request_data
        summary = {"request_id": self.request_id}
        summary["created_at"] = self.created_at.isoformat()
        summary["data_count"] = len(collected)
        summary["data"] = collected
        return summary

# Request-level dependency cache, keyed by the id stored on request.state
request_services = {}

def get_request_scoped_service(request: Request):
    """Return the ``RequestScopedService`` belonging to ``request``.

    The first call for a given ``request.state.request_id`` creates the
    service and caches it in ``request_services``; later calls with the same
    id return that instance.

    When the request carries no id, a fresh uncached service is returned
    with the id ``'unknown'``.
    """
    request_id = getattr(request.state, 'request_id', None)

    if request_id is None:
        # Do not cache under a shared fallback key: every id-less request
        # would get the same instance, see each other's data, and the entry
        # would never be evicted.
        request_id_var.set('unknown')
        return RequestScopedService()

    if request_id not in request_services:
        request_id_var.set(request_id)
        request_services[request_id] = RequestScopedService()

    return request_services[request_id]

@app.middleware("http")
async def add_request_id(request: Request, call_next):
    """Tag each request with a fresh UUID and drop its cached service afterwards.

    The id is stored on ``request.state.request_id`` and in
    ``request_id_var`` before the request is passed on.
    """
    request_id = str(uuid.uuid4())
    request.state.request_id = request_id
    request_id_var.set(request_id)

    try:
        return await call_next(request)
    finally:
        # Evict even when the handler raises, so failed requests do not
        # leave entries behind in the module-level cache.
        request_services.pop(request_id, None)

@app.post("/request-scoped/data/")
async def add_request_data(
    key: str,
    value: str,
    service: RequestScopedService = Depends(get_request_scoped_service)
):
    """Record ``key``/``value`` on the request's service and confirm it."""
    service.add_data(key, value)
    confirmation = f"Added data: {key} = {value}"
    return {"message": confirmation}

@app.get("/request-scoped/summary/")
async def get_request_summary(
    service: RequestScopedService = Depends(get_request_scoped_service)
):
    """Return the summary dict of the service resolved for this request."""
    return service.get_summary()

🧪 Dependency Testing

Dependency Overrides

python
from fastapi.testclient import TestClient

# Testing dependencies
class MockDatabaseConnection:
    """Test double for ``DatabaseConnection`` with no delays.

    It offers the same ``connect`` / ``disconnect`` / ``execute_query``
    coroutines, because ``UserService`` awaits ``connect()`` and
    ``disconnect()`` around every query.
    """

    def __init__(self):
        self.connected = True

    async def connect(self):
        """Mark the mock as connected and return ``self``."""
        self.connected = True
        return self

    async def disconnect(self):
        """Mark the mock as disconnected."""
        self.connected = False

    async def execute_query(self, query: str):
        """Return a canned result string that echoes ``query``."""
        return f"Mock result for: {query}"

def get_mock_db():
    """Return a new ``MockDatabaseConnection``; used as the override for ``DatabaseConnection``."""
    return MockDatabaseConnection()

# Testing
def test_with_dependency_override():
    """GET /users/1 is answered from the mock once ``DatabaseConnection`` is overridden."""
    app.dependency_overrides[DatabaseConnection] = get_mock_db
    try:
        with TestClient(app) as client:
            response = client.get("/users/1")
            assert response.status_code == 200
            assert "Mock result" in response.json()["query_result"]
    finally:
        # Remove only this test's override, and do so even when an assertion
        # fails, so it cannot leak into later tests.
        app.dependency_overrides.pop(DatabaseConnection, None)

# Testing specific dependency
def test_user_service():
    """``UserService`` returns the mock's query result when built directly with a mock."""
    import asyncio

    mock_db = MockDatabaseConnection()
    user_service = UserService(db=mock_db)

    # The test function is synchronous, so run the coroutine on a fresh
    # event loop; a bare ``await`` here is a syntax error.
    result = asyncio.run(user_service.get_user_by_id(1))
    assert result["user_id"] == 1
    assert "Mock result" in result["query_result"]

Summary

This chapter detailed FastAPI's dependency injection system:

  • Basic Dependencies: Function dependencies, parameter dependencies, async dependencies
  • Class Dependencies: Service classes, configuration classes, multi-layer dependencies
  • Security Dependencies: Authentication, authorization, permission checking
  • Scope Management: Singleton, request-scoped, dependency caching
  • Testing Support: Dependency overrides, mock dependencies

Dependency injection is a powerful FastAPI feature that makes code more modular, testable, and maintainable.

Dependency Injection Best Practices

  • Give each dependency a single responsibility
  • Design dependency lifecycles deliberately
  • Use type hints to enhance code readability
  • Make good use of dependency overrides for testing
  • Avoid circular dependencies
  • Consider performance impact of dependencies

In the next chapter, we will learn FastAPI's exception handling mechanism and understand how to gracefully handle various error situations.

Content is for learning and research only.