PyTorch Model Deployment
Deployment Overview
Model deployment is a crucial step in putting trained deep learning models into production environments. PyTorch provides multiple deployment solutions, from simple scripting deployment to high-performance production-grade services.
python
import numpy as np
import onnx
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.jit import script, trace
import tensorrt as trt

Model Serialization and Saving
1. Standard Model Saving
python
# Save complete model (not recommended for production)
# NOTE(review): torch.save(model) pickles the class by reference, so loading
# requires the identical source tree; prefer the state_dict form below.
torch.save(model, 'model_complete.pth')
# Save model state dict (recommended)
torch.save(model.state_dict(), 'model_weights.pth')
# Save training checkpoint
# Bundles everything needed to resume training, not just the weights.
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
'config': model_config
}
torch.save(checkpoint, 'checkpoint.pth')
# Load model
def load_model(model_class, weights_path, config):
model = model_class(config)
model.load_state_dict(torch.load(weights_path, map_location='cpu'))
model.eval()
return model2. TorchScript Deployment
python
class ImageClassifier(nn.Module):
    """Minimal CNN classifier used throughout the deployment examples.

    Two conv layers -> global average pool -> linear head. The adaptive
    pooling makes the network input-size agnostic (any HxW works).
    """

    def __init__(self, num_classes=10):
        super(ImageClassifier, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        # (N, 3, H, W) -> (N, 64, H, W) -> (N, 128, H, W)
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.pool(x)           # (N, 128, 1, 1)
        x = x.view(x.size(0), -1)  # flatten to (N, 128)
        x = self.fc(x)             # raw logits, no softmax
        return x
# Method 1: Scripting
model = ImageClassifier()
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.save('model_scripted.pt')
# Method 2: Tracing
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)
traced_model.save('model_traced.pt')
# Load TorchScript model
loaded_model = torch.jit.load('model_scripted.pt')
loaded_model.eval()
# Inference
with torch.no_grad():
output = loaded_model(example_input)
print(f"Output shape: {output.shape}")3. ONNX Export
python
import onnx
import onnxruntime as ort
def export_to_onnx(model, example_input, onnx_path):
    """Export `model` to ONNX and validate the resulting file.

    Args:
        model: A torch.nn.Module (switched to eval mode here).
        example_input: Tensor used to trace the graph shapes.
        onnx_path: Destination .onnx file path.
    """
    model.eval()
    torch.onnx.export(
        model,                        # Model
        example_input,                # Example input
        onnx_path,                    # Output path
        export_params=True,           # store trained weights in the file
        opset_version=11,             # ONNX operator set version
        do_constant_folding=True,     # fold constant subgraphs at export time
        input_names=['input'],        # Input names
        output_names=['output'],      # Output names
        dynamic_axes={                # allow variable batch size at runtime
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )
    # Structural validation: raises if the exported graph is malformed
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print(f"ONNX model saved to: {onnx_path}")
# Inference with ONNX Runtime
def onnx_inference(onnx_path, input_data):
    """Run one forward pass through ONNX Runtime.

    Args:
        onnx_path: Path to an exported .onnx model.
        input_data: torch.Tensor; converted to numpy for the runtime.

    Returns:
        numpy array holding the model output.
    """
    session = ort.InferenceSession(onnx_path)
    # Names were fixed at export time ('input'/'output'), but reading them
    # from the session keeps this helper generic.
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    result = session.run([output_name], {input_name: input_data.numpy()})
    return result[0]
# Export and test
model = ImageClassifier()
example_input = torch.randn(1, 3, 224, 224)
export_to_onnx(model, example_input, 'model.onnx')
# ONNX inference test
onnx_output = onnx_inference('model.onnx', example_input)
print(f"ONNX inference result shape: {onnx_output.shape}")Web Service Deployment
1. Flask Deployment
python
from flask import Flask, request, jsonify
import torch
import torchvision.transforms as transforms
from PIL import Image
import io
import base64
app = Flask(__name__)
# Global variables — populated once by load_model() before serving requests
model = None
transform = None
class_names = ['cat', 'dog', 'bird', 'fish', 'horse']
def load_model():
    """Load the TorchScript model and build the preprocessing pipeline.

    Populates the module-level `model` and `transform` globals; call once
    before the server starts handling requests.
    """
    global model, transform
    # TorchScript archive — no Python model class needed at serving time
    model = torch.jit.load('model_scripted.pt', map_location='cpu')
    model.eval()
    # Standard ImageNet preprocessing — must match the training-time transforms
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
def preprocess_image(image_bytes):
    """Decode raw image bytes into a normalized (1, 3, 224, 224) batch tensor."""
    # convert('RGB') normalizes palette/grayscale/alpha images to 3 channels
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    image_tensor = transform(image).unsqueeze(0)  # add the batch dimension
    return image_tensor
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image; returns the top-5 classes as JSON."""
    try:
        # Get image data
        if 'image' not in request.files:
            return jsonify({'error': 'No image uploaded'}), 400
        image_file = request.files['image']
        image_bytes = image_file.read()
        # Preprocess
        input_tensor = preprocess_image(image_bytes)
        # Inference
        with torch.no_grad():
            outputs = model(input_tensor)
            # outputs[0]: drop the batch dim, softmax over the class axis
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
        # Get top-5 predictions
        top5_prob, top5_idx = torch.topk(probabilities, 5)
        results = []
        for i in range(5):
            results.append({
                'class': class_names[top5_idx[i].item()],
                'probability': top5_prob[i].item()
            })
        return jsonify({
            'success': True,
            'predictions': results
        })
    except Exception as e:
        # Broad catch is intentional at this API boundary: any failure maps to 500
        return jsonify({'error': str(e)}), 500
@app.route('/health', methods=['GET'])
def health_check():
"""Health check"""
return jsonify({'status': 'healthy'})
if __name__ == '__main__':
load_model()
app.run(host='0.0.0.0', port=5000, debug=False)2. FastAPI Deployment
python
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import torch
import torchvision.transforms as transforms
from PIL import Image
import io
from typing import List
import uvicorn
app = FastAPI(title="PyTorch Model API", version="1.0.0")
# Global variables — populated by the startup hook below
model = None
transform = None
class_names = ['cat', 'dog', 'bird', 'fish', 'horse']
# NOTE(review): @app.on_event is deprecated in recent FastAPI versions in
# favor of lifespan handlers — kept here for compatibility with the example.
@app.on_event("startup")
async def load_model():
    """Load the TorchScript model and preprocessing pipeline at startup."""
    global model, transform
    model = torch.jit.load('model_scripted.pt', map_location='cpu')
    model.eval()
    # Standard ImageNet preprocessing — must match the training-time transforms
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
class PredictionResponse:
    """Simple value object pairing a class label with its probability.

    NOTE(review): unused by the endpoint below, which returns plain dicts;
    consider a pydantic BaseModel if a typed response schema is wanted.
    """

    def __init__(self, class_name: str, probability: float):
        self.class_name = class_name
        self.probability = probability
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
"""Prediction endpoint"""
try:
# Verify file type
if not file.content_type.startswith('image/'):
raise HTTPException(status_code=400, detail="File must be an image format")
# Read and preprocess image
image_bytes = await file.read()
image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
input_tensor = transform(image).unsqueeze(0)
# Inference
with torch.no_grad():
outputs = model(input_tensor)
probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
# Get top-5 predictions
top5_prob, top5_idx = torch.topk(probabilities, 5)
predictions = []
for i in range(5):
predictions.append({
"class_name": class_names[top5_idx[i].item()],
"probability": float(top5_prob[i].item())
})
return {"predictions": predictions}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
"""Health check"""
return {"status": "healthy"}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)Container Deployment
1. Docker Deployment
dockerfile
# Dockerfile
FROM python:3.9-slim

# Set working directory
WORKDIR /app

# Install system libraries commonly needed for image handling, then drop the
# apt cache to keep the layer small
RUN apt-get update && apt-get install -y \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libgomp1 \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first so the dependency layer is cached across code changes
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Expose API port
EXPOSE 8000

# Startup command
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

yaml
# docker-compose.yml
version: '3.8'

services:
  pytorch-api:
    build: .
    ports:
      - "8000:8000"
    volumes:
      - ./models:/app/models      # mount model files from the host
    environment:
      - MODEL_PATH=/app/models/model_scripted.pt
    restart: unless-stopped

  nginx:
    image: nginx:alpine           # reverse proxy in front of the API
    ports:
      - "80:80"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf
    depends_on:
      - pytorch-api
    restart: unless-stopped

2. Kubernetes Deployment
yaml
# deployment.yaml — 3-replica Deployment plus a LoadBalancer Service
apiVersion: apps/v1
kind: Deployment
metadata:
  name: pytorch-model-deployment
spec:
  replicas: 3
  selector:
    matchLabels:
      app: pytorch-model
  template:
    metadata:
      labels:
        app: pytorch-model
    spec:
      containers:
      - name: pytorch-api
        image: pytorch-model:latest
        ports:
        - containerPort: 8000
        resources:
          requests:
            memory: "512Mi"
            cpu: "500m"
          limits:
            memory: "1Gi"
            cpu: "1000m"
        env:
        - name: MODEL_PATH
          value: "/app/models/model_scripted.pt"
        livenessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 30   # give the model time to load
          periodSeconds: 10
        readinessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 5
          periodSeconds: 5
---
apiVersion: v1
kind: Service
metadata:
  name: pytorch-model-service
spec:
  selector:
    app: pytorch-model
  ports:
  - protocol: TCP
    port: 80
    targetPort: 8000
  type: LoadBalancer

High-Performance Inference
1. Batch Inference Optimization
python
class BatchInferenceOptimizer:
def __init__(self, model, max_batch_size=32, timeout=0.1):
self.model = model
self.max_batch_size = max_batch_size
self.timeout = timeout
self.batch_queue = []
self.result_futures = []
async def predict(self, input_data):
"""Async batch inference"""
import asyncio
from concurrent.futures import Future
future = Future()
self.batch_queue.append((input_data, future))
# If batch size reached or timeout, perform inference
if len(self.batch_queue) >= self.max_batch_size:
await self._process_batch()
else:
# Set timeout handling
asyncio.create_task(self._timeout_handler())
return await asyncio.wrap_future(future)
async def _process_batch(self):
"""Process batch data"""
if not self.batch_queue:
return
# Collect batch data
batch_data = []
futures = []
for data, future in self.batch_queue:
batch_data.append(data)
futures.append(future)
self.batch_queue.clear()
# Batch inference
try:
batch_input = torch.stack(batch_data)
with torch.no_grad():
batch_output = self.model(batch_input)
# Distribute results
for i, future in enumerate(futures):
future.set_result(batch_output[i])
except Exception as e:
# Error handling
for future in futures:
future.set_exception(e)
async def _timeout_handler(self):
"""Timeout handling"""
import asyncio
await asyncio.sleep(self.timeout)
if self.batch_queue:
await self._process_batch()2. TensorRT Optimization
python
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
class TensorRTInference:
def __init__(self, onnx_path, trt_path=None):
self.onnx_path = onnx_path
self.trt_path = trt_path or onnx_path.replace('.onnx', '.trt')
# Build TensorRT engine
self.engine = self._build_engine()
self.context = self.engine.create_execution_context()
# Allocate GPU memory
self._allocate_buffers()
def _build_engine(self):
"""Build TensorRT engine"""
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
# Parse ONNX model
with open(self.onnx_path, 'rb') as model:
if not parser.parse(model.read()):
for error in range(parser.num_errors):
print(parser.get_error(error))
return None
# Configure builder
config = builder.create_builder_config()
config.max_workspace_size = 1 << 30 # 1GB
config.set_flag(trt.BuilderFlag.FP16) # Enable FP16
# Build engine
engine = builder.build_engine(network, config)
# Save engine
with open(self.trt_path, 'wb') as f:
f.write(engine.serialize())
return engine
def _allocate_buffers(self):
"""Allocate GPU memory buffers"""
self.inputs = []
self.outputs = []
self.bindings = []
for binding in self.engine:
size = trt.volume(self.engine.get_binding_shape(binding)) * self.engine.max_batch_size
dtype = trt.nptype(self.engine.get_binding_dtype(binding))
# Allocate host and device memory
host_mem = cuda.pagelocked_empty(size, dtype)
device_mem = cuda.mem_alloc(host_mem.nbytes)
self.bindings.append(int(device_mem))
if self.engine.binding_is_input(binding):
self.inputs.append({'host': host_mem, 'device': device_mem})
else:
self.outputs.append({'host': host_mem, 'device': device_mem})
def infer(self, input_data):
"""TensorRT inference"""
# Copy input data to GPU
np.copyto(self.inputs[0]['host'], input_data.ravel())
cuda.memcpy_htod(self.inputs[0]['device'], self.inputs[0]['host'])
# Execute inference
self.context.execute_v2(bindings=self.bindings)
# Copy output data to CPU
cuda.memcpy_dtoh(self.outputs[0]['host'], self.outputs[0]['device'])
return self.outputs[0]['host']Edge Device Deployment
1. Mobile Deployment (PyTorch Mobile)
python
# Model optimization for mobile
def optimize_for_mobile(model, example_input):
"""Optimize model for mobile deployment"""
model.eval()
# Trace model
traced_model = torch.jit.trace(model, example_input)
# Mobile optimization
from torch.utils.mobile_optimizer import optimize_for_mobile
optimized_model = optimize_for_mobile(traced_model)
# Save optimized model
optimized_model._save_for_lite_interpreter("model_mobile.ptl")
return optimized_model
# Usage example
model = ImageClassifier()
example_input = torch.randn(1, 3, 224, 224)
mobile_model = optimize_for_mobile(model, example_input)2. Quantization Deployment
python
def quantize_model_for_deployment(model, calibration_loader):
"""Quantize model for deployment"""
# Set quantization configuration
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# Prepare quantization
model_prepared = torch.quantization.prepare(model, inplace=False)
# Calibrate
model_prepared.eval()
with torch.no_grad():
for data, _ in calibration_loader:
model_prepared(data)
# Convert to quantized model
model_quantized = torch.quantization.convert(model_prepared, inplace=False)
return model_quantized
# Dynamic quantization (simpler approach)
def dynamic_quantize_model(model):
    """Dynamically quantize Linear layers to int8 (weights only).

    No calibration data needed: activations are quantized on the fly at
    inference time, so this is the easiest win for MLP/RNN-heavy models.
    """
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    return quantized_model

Monitoring and Logging
1. Performance Monitoring
python
import time
import psutil
import logging
from functools import wraps
class PerformanceMonitor:
    """Tracks request counts, inference latency, and host resource usage."""

    def __init__(self):
        # Running aggregates; the average is recomputed after every request
        self.metrics = {
            'request_count': 0,
            'total_inference_time': 0,
            'average_inference_time': 0,
            'memory_usage': 0,
            'cpu_usage': 0
        }

    def log_inference_time(self, func):
        """Decorator: time each call to `func` and update the aggregates."""
        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.time()
            result = func(*args, **kwargs)
            end_time = time.time()
            inference_time = end_time - start_time
            self.metrics['request_count'] += 1
            self.metrics['total_inference_time'] += inference_time
            self.metrics['average_inference_time'] = (
                self.metrics['total_inference_time'] / self.metrics['request_count']
            )
            logging.info(f"Inference time: {inference_time:.4f}s")
            return result
        return wrapper

    def update_system_metrics(self):
        """Sample host memory and CPU utilization via psutil."""
        self.metrics['memory_usage'] = psutil.virtual_memory().percent
        self.metrics['cpu_usage'] = psutil.cpu_percent()

    def get_metrics(self):
        """Return all metrics, refreshing the system-level ones first."""
        self.update_system_metrics()
        return self.metrics
# Usage example
monitor = PerformanceMonitor()
@monitor.log_inference_time
def model_inference(input_data):
with torch.no_grad():
return model(input_data)2. Logging Configuration
python
import logging
import json
from datetime import datetime
class ModelLogger:
    """Logs predictions and errors as JSON payloads to a file and the console."""

    def __init__(self, log_file='model_service.log'):
        self.logger = logging.getLogger('ModelService')
        self.logger.setLevel(logging.INFO)
        # Bug fix: only attach handlers once — logging.getLogger returns the
        # same logger every time, so re-instantiating ModelLogger used to
        # stack duplicate handlers and emit every line multiple times.
        if not self.logger.handlers:
            # File handler
            file_handler = logging.FileHandler(log_file)
            file_handler.setLevel(logging.INFO)
            # Console handler
            console_handler = logging.StreamHandler()
            console_handler.setLevel(logging.INFO)
            # Shared formatter
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            file_handler.setFormatter(formatter)
            console_handler.setFormatter(formatter)
            self.logger.addHandler(file_handler)
            self.logger.addHandler(console_handler)

    def log_prediction(self, input_info, prediction, confidence, inference_time):
        """Log one successful prediction as a JSON document."""
        log_data = {
            'timestamp': datetime.now().isoformat(),
            'input_info': input_info,
            'prediction': prediction,
            'confidence': confidence,
            'inference_time': inference_time
        }
        self.logger.info(f"Prediction completed: {json.dumps(log_data, ensure_ascii=False)}")

    def log_error(self, error_msg, input_info=None):
        """Log a failed prediction as a JSON document."""
        log_data = {
            'timestamp': datetime.now().isoformat(),
            'error': error_msg,
            'input_info': input_info
        }
        self.logger.error(f"Prediction error: {json.dumps(log_data, ensure_ascii=False)}")

Summary
PyTorch model deployment covers the complete process from development to production:
- Model Serialization: TorchScript, ONNX, and other format conversions
- Web Services: API services using Flask, FastAPI, etc.
- Containerization: Docker, Kubernetes, and other container deployments
- Performance Optimization: Batch inference, TensorRT, and other acceleration techniques
- Edge Deployment: Mobile, quantization, and other lightweight solutions
- Monitoring and Logging: Performance monitoring and complete logging system
Mastering these deployment techniques will help you successfully put PyTorch models into production environments!