FastAPI Dependency Injection
Overview
Dependency Injection is one of FastAPI's core features, providing an elegant way to manage application dependencies. Through dependency injection, we can decouple code, improve testability, and enhance maintainability. This chapter takes an in-depth look at FastAPI's dependency injection system.
🔧 Dependency Injection Basics
Simple Dependencies
python
import asyncio
import os
import time
import uuid
from datetime import datetime
from typing import Optional

from fastapi import Depends, FastAPI, Header, HTTPException, Query, Request
# Application instance; the route decorators below register endpoints on it.
app = FastAPI()
def get_current_time():
    """Return the server's current local time as a naive ``datetime``.

    This is the simplest kind of dependency: a plain function with no parameters.
    """
    import datetime as _dt

    return _dt.datetime.now()
def get_user_agent(user_agent: Optional[str] = Header(None)):
    """Read the ``User-Agent`` request header.

    FastAPI maps the ``user_agent`` parameter to the ``User-Agent`` header.
    An absent or empty header yields the placeholder ``"Unknown"``.
    """
    return user_agent if user_agent else "Unknown"
@app.get("/info/")
async def get_info(
    current_time: datetime = Depends(get_current_time),
    user_agent: str = Depends(get_user_agent),
):
    """Report the server time and the caller's User-Agent, both injected via ``Depends``."""
    payload = {"current_time": current_time.isoformat()}
    payload["user_agent"] = user_agent
    payload["message"] = "Dependency injection example"
    return payload
# Dependencies can also be async
async def get_database_info():
    """Return static database metadata after a short simulated I/O wait."""
    await asyncio.sleep(0.1)  # stands in for a real round-trip to the database
    info = dict(database="PostgreSQL", version="13.7", status="connected")
    return info
@app.get("/db-info/")
async def get_db_info(db_info: dict = Depends(get_database_info)):
    """Return the metadata produced by the async ``get_database_info`` dependency."""
    return db_info


# Dependencies with Parameters
python
def get_query_extractor(query_param: str = "q", default_value: str = ""):
    """Build a dependency that reads a single query-string parameter.

    Args:
        query_param: Name of the query parameter to read.
        default_value: Value returned when the parameter is absent.

    Returns:
        A callable suitable for ``Depends(...)``. It takes the request and
        returns the parameter's raw string value.
    """

    def query_extractor(request: Request):
        params = request.query_params
        return params.get(query_param, default_value)

    return query_extractor
def get_pagination_params(page: int = Query(1, ge=1), size: int = Query(10, ge=1, le=100)):
    """Validate pagination query parameters and derive the row offset.

    ``page`` is 1-based and must be at least 1. ``size`` must be between 1 and 100.
    The returned ``offset`` is the number of rows to skip.
    """
    offset = (page - 1) * size
    return dict(page=page, size=size, offset=offset)
def get_sorting_params(
    sort_by: str = Query("created_at", description="Sort field"),
    sort_order: str = Query("desc", regex="^(asc|desc)$", description="Sort direction"),
):
    """Collect sorting query parameters.

    ``sort_order`` is restricted to ``asc`` or ``desc``. ``sort_by`` is not
    validated here, so callers must whitelist it before using it in a query.
    """
    return dict(sort_by=sort_by, sort_order=sort_order)
@app.get("/posts/")
async def list_posts(
    pagination: dict = Depends(get_pagination_params),
    sorting: dict = Depends(get_sorting_params),
    search_query: str = Depends(get_query_extractor("search", "")),
):
    """List posts using three composed parameter dependencies.

    NOTE: the permission example later in this chapter also registers
    ``GET /posts/``. If both live in one app, only the first-registered route
    is reachable.
    """
    return {
        "pagination": pagination,
        "sorting": sorting,
        "search_query": search_query,
        "posts": [],  # placeholder: this example has no real post storage
    }


# 🏗️ Classes as Dependencies
Simple Class Dependencies
python
class DatabaseConnection:
    """Simulated async database connection used to demonstrate class dependencies.

    No real I/O happens; ``asyncio.sleep`` stands in for network latency.
    """

    def __init__(self):
        self.host = "localhost"
        self.port = 5432
        self.database = "myapp"
        self.connected = False

    async def connect(self):
        """Simulate opening the connection and return ``self``."""
        await asyncio.sleep(0.1)
        self.connected = True
        return self

    async def disconnect(self):
        """Simulate closing the connection."""
        self.connected = False

    async def execute_query(self, query: str):
        """Simulate running ``query`` and return a description of it.

        Raises:
            RuntimeError: if called before ``connect()``. This is a subclass of
                ``Exception``, so existing broad handlers still catch it.
        """
        if not self.connected:
            raise RuntimeError("Database not connected")
        await asyncio.sleep(0.05)
        return f"Execute query: {query}"
class UserService:
    """User operations backed by an injected ``DatabaseConnection``.

    Each operation connects, runs one query, and always disconnects again,
    even if the query raises.
    """

    def __init__(self, db: DatabaseConnection = Depends()):
        # Bare Depends() tells FastAPI to build the annotated type
        # (DatabaseConnection) and pass it in.
        self.db = db

    async def get_user_by_id(self, user_id: int):
        """Look up ``user_id`` and return a simulated user record."""
        await self.db.connect()
        try:
            # Interpolation is only safe because the endpoint has FastAPI coerce
            # user_id to int; with a real driver, pass it as a bound parameter.
            result = await self.db.execute_query(f"SELECT * FROM users WHERE id = {user_id}")
        finally:
            # Disconnect even when the query fails so the connection is not left open.
            await self.db.disconnect()
        return {"user_id": user_id, "name": f"User {user_id}", "query_result": result}

    async def create_user(self, user_data: dict):
        """Simulate inserting ``user_data`` and return a confirmation payload."""
        await self.db.connect()
        try:
            await self.db.execute_query("INSERT INTO users VALUES (...)")
        finally:
            await self.db.disconnect()
        return {"message": "User created successfully", "user_data": user_data}
@app.get("/users/{user_id}")
async def get_user(user_id: int, user_service: UserService = Depends()):
    """Fetch one user; FastAPI builds ``UserService`` and its ``DatabaseConnection``."""
    return await user_service.get_user_by_id(user_id)


@app.post("/users/")
async def create_user(user_data: dict, user_service: UserService = Depends()):
    """Create a user from the JSON request body."""
    return await user_service.create_user(user_data)


# Classes with Configuration as Dependencies
python
from pydantic import BaseSettings
import httpx
class APISettings(BaseSettings):
    """Configuration for ``ExternalAPIClient``.

    As a ``BaseSettings`` subclass, field values may come from environment
    variables or the ``.env`` file named in ``Config``; the literals below are
    fallbacks.
    """
    external_api_url: str = "https://api.example.com"  # base URL for all client requests
    api_key: str = "default-key"  # placeholder; sent as a Bearer token, so supply a real key via env
    timeout: int = 30  # passed to httpx.AsyncClient as its timeout (seconds)
    max_retries: int = 3  # number of attempts ExternalAPIClient.get_data makes on transport errors
    class Config:
        # Extra source of settings values, read in addition to the process environment.
        env_file = ".env"
class ExternalAPIClient:
    """HTTP client for the external API, configured from ``APISettings``.

    Use it as an async context manager. Entering creates the underlying
    ``httpx.AsyncClient``; exiting closes it and clears ``self.client``.
    """

    def __init__(self, settings: APISettings = Depends()):
        self.settings = settings
        self.client = None  # created in __aenter__, released in __aexit__

    async def __aenter__(self):
        """Create the HTTP client with base URL, timeout and auth header; return ``self``."""
        self.client = httpx.AsyncClient(
            base_url=self.settings.external_api_url,
            timeout=self.settings.timeout,
            headers={"Authorization": f"Bearer {self.settings.api_key}"},
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP client; exceptions raised in the body are not suppressed."""
        if self.client is not None:
            await self.client.aclose()
            # Reset so a later get_data() fails fast instead of using a closed client.
            self.client = None

    async def get_data(self, endpoint: str):
        """GET ``endpoint`` and return the decoded JSON body.

        Transport errors (``httpx.RequestError``) are retried with exponential
        backoff. ``settings.max_retries`` is the total number of attempts, and
        at least one is always made. Error statuses raised by
        ``raise_for_status`` are not retried.

        Raises:
            RuntimeError: if called outside the ``async with`` block.
            httpx.RequestError: if every attempt fails at the transport level.
            httpx.HTTPStatusError: on a 4xx/5xx response.
        """
        if self.client is None:
            raise RuntimeError("Client not initialized")
        # Guard against max_retries <= 0, which previously made no request and returned None.
        attempts = max(1, self.settings.max_retries)
        for attempt in range(attempts):
            try:
                response = await self.client.get(endpoint)
                response.raise_for_status()
                return response.json()
            except httpx.RequestError:
                if attempt == attempts - 1:
                    raise
                await asyncio.sleep(2 ** attempt)  # exponential backoff: 1s, 2s, 4s, ...
async def get_api_client():
    """Provide an ``ExternalAPIClient`` whose HTTP session is open while it is in use.

    This is a generator dependency. When FastAPI resumes it after the
    dependency is no longer needed, the ``async with`` block exits and the
    client is closed.
    """
    async with ExternalAPIClient(APISettings()) as client:
        yield client
@app.get("/external-data/{endpoint}")
async def get_external_data(
    endpoint: str,
    api_client: ExternalAPIClient = Depends(get_api_client),
):
    """Proxy a GET to the external API, mapping any client failure to HTTP 503."""
    try:
        data = await api_client.get_data(endpoint)
    except Exception as e:
        # Chain the original error so server-side tracebacks keep the root cause.
        # NOTE(review): the detail echoes the internal error text to the caller;
        # consider a generic message in production.
        raise HTTPException(status_code=503, detail=f"External API call failed: {str(e)}") from e
    return {"data": data, "source": "external_api"}


# 🔄 Sub-dependencies
Multiple Layer Dependencies
python
class ConfigService:
    """Static application configuration exposed through a dict-like ``get``."""

    def __init__(self):
        self.config = dict(
            app_name="MyApp",
            version="1.0.0",
            debug=True,
            database_url="postgresql://localhost/myapp",
        )

    def get(self, key: str, default=None):
        """Return the setting stored under ``key``, or ``default`` if it is undefined."""
        try:
            return self.config[key]
        except KeyError:
            return default
class LoggingService:
    """Print-based logger whose prefix and verbosity come from ``ConfigService``."""

    def __init__(self, config: ConfigService = Depends()):
        self.config = config
        self.app_name = config.get("app_name")
        self.debug = config.get("debug")

    def log_info(self, message: str):
        """Print ``message`` tagged DEBUG in debug mode and INFO otherwise."""
        if self.debug:
            level = "DEBUG"
        else:
            level = "INFO"
        print(f"[{level}] {self.app_name}: {message}")

    def log_error(self, message: str):
        """Print ``message`` with an ERROR tag."""
        print(f"[ERROR] {self.app_name}: {message}")
class CacheService:
    """In-memory key/value cache with per-entry expiry and hit/miss logging.

    NOTE(review): FastAPI calls class dependencies for each request, so an
    instance injected with ``Depends()`` starts empty on every request. Share
    one instance through a provider if entries should outlive a request.
    """

    def __init__(self, config: ConfigService = Depends(), logger: LoggingService = Depends()):
        self.config = config
        self.logger = logger
        self.cache = {}  # key -> cached value
        self._expires_at = {}  # key -> time.monotonic() deadline after which the entry is stale
        self.logger.log_info("Cache service initialized")

    def get(self, key: str):
        """Return the cached value for ``key``, or None on a miss or an expired entry."""
        deadline = self._expires_at.get(key)
        if deadline is not None and time.monotonic() >= deadline:
            # Expired: evict so it is reported, and treated, as a miss.
            self.cache.pop(key, None)
            self._expires_at.pop(key, None)
        # Decide a hit by membership, not truthiness, so falsy values (0, "", {}) count as hits.
        if key in self.cache:
            self.logger.log_info(f"Cache hit: {key}")
            return self.cache[key]
        self.logger.log_info(f"Cache miss: {key}")
        return None

    def set(self, key: str, value, ttl: int = 300):
        """Store ``value`` under ``key``; it expires ``ttl`` seconds from now."""
        self.cache[key] = value
        self._expires_at[key] = time.monotonic() + ttl
        self.logger.log_info(f"Cache set: {key}")
class BusinessService:
    """Business-data service composed from config, logging and cache dependencies."""

    def __init__(
        self,
        config: ConfigService = Depends(),
        logger: LoggingService = Depends(),
        cache: CacheService = Depends(),
    ):
        # Within one request FastAPI resolves each dependency once and reuses it,
        # so config/logger here are the same objects CacheService receives.
        self.config = config
        self.logger = logger
        self.cache = cache

    async def get_business_data(self, data_id: str):
        """Return the record for ``data_id``, serving it from the cache when present."""
        cache_key = f"business_data:{data_id}"
        cached_data = self.cache.get(cache_key)
        # Explicit None check: the cache returns None on a miss, and a hit must not
        # be ignored just because the cached value is falsy.
        if cached_data is not None:
            return cached_data

        self.logger.log_info(f"Processing business data: {data_id}")
        await asyncio.sleep(0.1)  # simulated processing time
        data = {
            "id": data_id,
            "name": f"Business Data {data_id}",
            "processed_at": datetime.now().isoformat(),
        }
        self.cache.set(cache_key, data)
        return data
@app.get("/business/{data_id}")
async def get_business_data(data_id: str, service: BusinessService = Depends()):
    """Return business data for ``data_id`` through the composed ``BusinessService``."""
    return await service.get_business_data(data_id)


# 🔒 Security Dependencies
Authentication Dependencies
python
import jwt
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
# Bearer-token scheme. Used as a dependency, it supplies HTTPAuthorizationCredentials;
# get_current_user reads the raw token from `.credentials`.
security = HTTPBearer()
class User:
    """Authenticated principal built from JWT claims."""

    def __init__(self, user_id: int, username: str, email: str, roles: list):
        self.user_id, self.username, self.email, self.roles = user_id, username, email, roles

    def has_role(self, role: str) -> bool:
        """Return True when ``role`` is among this user's roles."""
        return role in self.roles

    def dict(self):
        """Return the user's attributes as a plain mapping (``roles`` is not copied)."""
        field_names = ("user_id", "username", "email", "roles")
        return {name: getattr(self, name) for name in field_names}
# HS256 signing key for JWTs. It is read from the environment so deployments do not
# ship the hard-coded placeholder; the literal is only a local-development fallback.
SECRET_KEY = os.environ.get("SECRET_KEY", "your-secret-key")
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Decode the bearer JWT and return the authenticated ``User``.

    The token is verified with ``SECRET_KEY`` using HS256. The claims
    ``user_id``, ``username``, ``email`` and ``roles`` (default ``[]``) populate
    the ``User``.

    Raises:
        HTTPException: 401 when the token is expired, invalid, or has no
            ``user_id`` claim.
    """
    # RFC 6750: a 401 for bearer-token auth should advertise the scheme.
    auth_headers = {"WWW-Authenticate": "Bearer"}
    try:
        payload = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=["HS256"])
    # ExpiredSignatureError subclasses InvalidTokenError, so it must be caught first.
    except jwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token expired", headers=auth_headers)
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token", headers=auth_headers)

    user_id = payload.get("user_id")
    if user_id is None:
        raise HTTPException(status_code=401, detail="Invalid token", headers=auth_headers)
    return User(user_id, payload.get("username"), payload.get("email"), payload.get("roles", []))
async def get_admin_user(current_user: User = Depends(get_current_user)):
    """Pass the authenticated user through only if they hold the ``admin`` role.

    Raises:
        HTTPException: 403 when the user is not an admin.
    """
    if current_user.has_role("admin"):
        return current_user
    raise HTTPException(status_code=403, detail="Admin privileges required")
def require_roles(*required_roles):
    """Build a dependency that admits users holding at least one of ``required_roles``.

    The returned coroutine function yields the current user, or raises
    ``HTTPException(403)`` listing the acceptable roles.
    """

    async def role_checker(current_user: User = Depends(get_current_user)):
        for role in required_roles:
            if current_user.has_role(role):
                return current_user
        allowed = ", ".join(required_roles)
        raise HTTPException(
            status_code=403,
            detail=f"Requires one of the following roles: {allowed}",
        )

    return role_checker
# Use authentication dependencies
@app.get("/profile/")
async def get_profile(current_user: User = Depends(get_current_user)):
    """Return the authenticated caller's profile."""
    return {"profile": current_user.dict()}


@app.get("/admin/users/")
async def list_all_users(admin_user: User = Depends(get_admin_user)):
    """List users (placeholder data); restricted to the ``admin`` role."""
    return {"users": [], "admin": admin_user.username}


@app.post("/admin/posts/")
async def create_post(
    post_data: dict,
    editor: User = Depends(require_roles("admin", "editor")),
):
    """Create a post; the caller needs the ``admin`` or ``editor`` role."""
    return {"message": "Post created successfully", "created_by": editor.username}


# Permission Dependencies
python
from enum import Enum
class Permission(str, Enum):
    """Fine-grained permissions enforced by ``PermissionChecker``.

    Mixing in ``str`` makes each member compare equal to its string value.
    """
    READ_POSTS = "read_posts"
    WRITE_POSTS = "write_posts"
    DELETE_POSTS = "delete_posts"
    MANAGE_USERS = "manage_users"
    # Not granted by any role mapping; system admins bypass the check by role instead.
    SYSTEM_ADMIN = "system_admin"
class PermissionChecker:
    """Callable dependency that admits only users holding ``required_permission``.

    Users with the ``system_admin`` role bypass the check. Everyone else gets
    the permissions granted by their roles in ``_ROLE_PERMISSIONS``.
    """

    # Role -> permissions it grants. Built once at class creation rather than on every request.
    _ROLE_PERMISSIONS = {
        "admin": [Permission.READ_POSTS, Permission.WRITE_POSTS, Permission.DELETE_POSTS, Permission.MANAGE_USERS],
        "editor": [Permission.READ_POSTS, Permission.WRITE_POSTS],
        "user": [Permission.READ_POSTS],
    }

    def __init__(self, required_permission: Permission):
        self.required_permission = required_permission

    async def __call__(self, current_user: User = Depends(get_current_user)):
        """Return ``current_user`` if authorized.

        Raises:
            HTTPException: 403 when the required permission is missing.
        """
        # System admin has all permissions.
        if current_user.has_role("system_admin"):
            return current_user
        if self.required_permission not in self._get_user_permissions(current_user):
            raise HTTPException(
                status_code=403,
                detail=f"Requires permission: {self.required_permission.value}",
            )
        return current_user

    def _get_user_permissions(self, user: User) -> list[Permission]:
        """Return the de-duplicated permissions granted by ``user.roles``.

        Permissions appear in first-seen order, so results are deterministic.
        Unknown roles grant nothing.
        """
        granted = (
            permission
            for role in user.roles
            for permission in self._ROLE_PERMISSIONS.get(role, [])
        )
        return list(dict.fromkeys(granted))
# Use permission dependencies
# NOTE: GET /posts/ is also registered by the pagination example earlier in this
# chapter; if both live in one app, only the first-registered route is reachable.
@app.get("/posts/")
async def list_posts(user: User = Depends(PermissionChecker(Permission.READ_POSTS))):
    """List posts (placeholder); requires ``read_posts``."""
    return {"posts": [], "user": user.username}


@app.post("/posts/")
async def create_post(
    post_data: dict,
    user: User = Depends(PermissionChecker(Permission.WRITE_POSTS)),
):
    """Create a post; requires ``write_posts``."""
    return {"message": "Post created successfully", "author": user.username}


@app.delete("/posts/{post_id}")
async def delete_post(
    post_id: int,
    user: User = Depends(PermissionChecker(Permission.DELETE_POSTS)),
):
    """Delete a post; requires ``delete_posts``."""
    return {"message": f"Post {post_id} deleted successfully", "deleted_by": user.username}


# 🎛️ Dependency Providers and Scopes
Singleton Dependencies
python
class SingletonService:
    """Shared mutable state: every ``SingletonService()`` call returns the same object."""

    _instance = None  # the one shared instance, created on first use
    _initialized = False  # set to True on the instance once its state is seeded

    def __new__(cls):
        instance = cls._instance
        if instance is None:
            instance = super().__new__(cls)
            cls._instance = instance
        return instance

    def __init__(self):
        # __init__ runs on every SingletonService() call; seed state only the first time.
        if self._initialized:
            return
        self.data = {}
        self.counter = 0
        self._initialized = True

    def increment(self):
        """Add one to the shared counter and return the new value."""
        self.counter = self.counter + 1
        return self.counter

    def set_data(self, key: str, value):
        """Store ``value`` under ``key`` in the shared dict."""
        self.data.update({key: value})

    def get_data(self, key: str):
        """Return the value stored under ``key``, or None if absent."""
        return self.data.get(key, None)
def get_singleton_service():
    """Provide the shared ``SingletonService`` instance to route handlers."""
    service = SingletonService()  # __new__ hands back the same object on every call
    return service
@app.get("/singleton/increment/")
async def increment_counter(service: SingletonService = Depends(get_singleton_service)):
    """Bump the shared counter and return its new value."""
    count = service.increment()
    return {"count": count}


@app.post("/singleton/data/{key}")
async def set_singleton_data(
    key: str,
    value: str,
    service: SingletonService = Depends(get_singleton_service),
):
    """Store ``value`` (a query parameter) under ``key`` in the shared service."""
    service.set_data(key, value)
    return {"message": f"Set {key} = {value}"}


@app.get("/singleton/data/{key}")
async def get_singleton_data(
    key: str,
    service: SingletonService = Depends(get_singleton_service),
):
    """Read ``key`` from the shared service; the value is null when unset."""
    value = service.get_data(key)
    return {"key": key, "value": value}


# Scoped Dependencies
python
from contextvars import ContextVar
# Context-local holder for the current request's ID; empty string until something sets it.
request_id_var: ContextVar[str] = ContextVar("request_id", default="")
class RequestScopedService:
    """Scratch pad that records data for one request under that request's ID."""

    def __init__(self):
        # Snapshot whatever request ID is set in the current context at construction time.
        self.request_id = request_id_var.get()
        self.created_at = datetime.now()
        self.request_data = {}

    def add_data(self, key: str, value):
        """Record ``value`` under ``key``."""
        self.request_data[key] = value

    def get_summary(self):
        """Return the request ID, creation time (ISO 8601) and collected data."""
        collected = self.request_data
        return {
            "request_id": self.request_id,
            "created_at": self.created_at.isoformat(),
            "data_count": len(collected),
            "data": collected,
        }
# Request-level dependency cache. Kept because the request-ID middleware removes
# entries from it; the provider below stores services on request.state, so this stays empty.
request_services = {}


def get_request_scoped_service(request: Request):
    """Return this request's ``RequestScopedService``, creating it on first use.

    The instance is stored on ``request.state``, so it is tied to this request
    and can never be handed to another one. A global dict keyed by request ID
    would give every request without an ID the same shared ``'unknown'`` entry.
    """
    service = getattr(request.state, "request_scoped_service", None)
    if service is None:
        # RequestScopedService reads the ID from the context variable in its constructor.
        request_id_var.set(getattr(request.state, "request_id", "unknown"))
        service = RequestScopedService()
        request.state.request_scoped_service = service
    return service
@app.middleware("http")
async def add_request_id(request: Request, call_next):
    """Tag each request with a fresh UUID and release its scoped services afterwards."""
    request_id = str(uuid.uuid4())
    request.state.request_id = request_id
    request_id_var.set(request_id)
    try:
        return await call_next(request)
    finally:
        # Runs even when the downstream app raises, so cache entries cannot leak.
        request_services.pop(request_id, None)
@app.post("/request-scoped/data/")
async def add_request_data(
    key: str,
    value: str,
    service: RequestScopedService = Depends(get_request_scoped_service),
):
    """Add a key/value pair (query parameters) to this request's scoped service."""
    service.add_data(key, value)
    return {"message": f"Added data: {key} = {value}"}


@app.get("/request-scoped/summary/")
async def get_request_summary(
    service: RequestScopedService = Depends(get_request_scoped_service),
):
    """Summarize this request's scoped service.

    Every HTTP request gets a new ID and a new service, so this does not see
    data added by an earlier POST.
    """
    return service.get_summary()


# 🧪 Dependency Testing
Dependency Overrides
python
from fastapi.testclient import TestClient
# Testing dependencies
class MockDatabaseConnection:
    """I/O-free stand-in for ``DatabaseConnection`` in tests.

    It implements the full async interface ``UserService`` awaits (connect,
    execute_query, disconnect). Without connect/disconnect, the overridden
    endpoint would fail with AttributeError.
    """

    def __init__(self):
        self.connected = True

    async def connect(self):
        """Mark the mock connected and return it, mirroring the real connect()."""
        self.connected = True
        return self

    async def disconnect(self):
        """Mark the mock disconnected."""
        self.connected = False

    async def execute_query(self, query: str):
        """Return a canned result that echoes ``query``."""
        return f"Mock result for: {query}"


def get_mock_db():
    """Dependency override that supplies a fresh ``MockDatabaseConnection``."""
    return MockDatabaseConnection()
# Testing
def test_with_dependency_override():
    """GET /users/1 should use the mock database when DatabaseConnection is overridden."""
    app.dependency_overrides[DatabaseConnection] = get_mock_db
    try:
        with TestClient(app) as client:
            response = client.get("/users/1")
        assert response.status_code == 200
        assert "Mock result" in response.json()["query_result"]
    finally:
        # Always remove the override so a failing assertion cannot leak it into other tests.
        app.dependency_overrides = {}
# Testing specific dependency
def test_user_service():
    """UserService can be tested directly by passing a mock DB instead of using Depends."""
    mock_db = MockDatabaseConnection()
    user_service = UserService(db=mock_db)
    # `await` is a SyntaxError inside a plain def; drive the coroutine with asyncio.run.
    result = asyncio.run(user_service.get_user_by_id(1))
    assert result["user_id"] == 1
    assert "Mock result" in result["query_result"]


# Summary
This chapter detailed FastAPI's dependency injection system:
- ✅ Basic Dependencies: Function dependencies, parameter dependencies, async dependencies
- ✅ Class Dependencies: Service classes, configuration classes, multi-layer dependencies
- ✅ Security Dependencies: Authentication, authorization, permission checking
- ✅ Scope Management: Singleton, request-scoped, dependency caching
- ✅ Testing Support: Dependency overrides, mock dependencies
Dependency injection is one of FastAPI's most powerful features: it makes code more modular, testable, and maintainable.
Dependency Injection Best Practices
- Keep dependencies single responsibility
- Design dependency lifecycles deliberately (per request vs. shared across requests)
- Use type hints to enhance code readability
- Make good use of dependency overrides for testing
- Avoid circular dependencies
- Consider performance impact of dependencies
In the next chapter, we will explore FastAPI's exception handling mechanism and learn how to handle various error situations gracefully.