Skip to content

PyTorch Model Deployment

Deployment Overview

Model deployment is the step that takes a trained deep learning model into a production environment. PyTorch supports several deployment paths, ranging from simple script-based deployment to high-performance, production-grade inference services.

python
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.jit import script, trace
import onnx
import tensorrt as trt

Model Serialization and Saving

1. Standard Model Saving

python
# Save complete model (not recommended for production)
# Serializes the whole module object, not just its tensors (contrast with the
# state-dict save below).
torch.save(model, 'model_complete.pth')

# Save model state dict (recommended)
# Only the parameter/buffer tensors are written; the model class must be
# reconstructed in code before loading them (see load_model below).
torch.save(model.state_dict(), 'model_weights.pth')

# Save training checkpoint
# `epoch`, `model`, `optimizer`, `loss` and `model_config` are not defined in
# this snippet; they are assumed to come from the surrounding training loop.
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    # NOTE(review): if `loss` is a tensor, consider `loss.item()` so the
    # checkpoint stores a plain number — confirm what `loss` holds.
    'loss': loss,
    'config': model_config
}
torch.save(checkpoint, 'checkpoint.pth')

# Load model
def load_model(model_class, weights_path, config):
    """Rebuild a model from its class and a saved state dict, ready for inference.

    Args:
        model_class: Callable invoked as ``model_class(config)`` to build a
            fresh model instance.
        weights_path: File written by ``torch.save(model.state_dict(), ...)``.
        config: Single positional argument forwarded to ``model_class``.

    Returns:
        The model with the saved weights loaded, switched to eval mode.
    """
    net = model_class(config)
    # map_location='cpu' lets weights saved on a GPU load on CPU-only hosts.
    # NOTE(review): for files from untrusted sources consider passing
    # weights_only=True (available in torch >= 1.13).
    state = torch.load(weights_path, map_location='cpu')
    net.load_state_dict(state)
    # Module.eval() returns the module itself.
    return net.eval()

2. TorchScript Deployment

python
class ImageClassifier(nn.Module):
    """Small CNN classifier used as the running example.

    Two 3x3 convolutions (3 -> 64 -> 128 channels, padding 1, each followed by
    ReLU), global average pooling to 1x1, then a linear head.

    Args:
        num_classes: Length of the output logit vector.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a batch of 3-channel images to logits of shape (N, num_classes)."""
        features = torch.relu(self.conv2(torch.relu(self.conv1(x))))
        pooled = self.pool(features)
        # Collapse the (128, 1, 1) pooled map into a 128-vector per sample.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

# Method 1: Scripting
# torch.jit.script compiles the module from its Python source.
model = ImageClassifier()
# Switch to eval mode before exporting.
model.eval()

scripted_model = torch.jit.script(model)
scripted_model.save('model_scripted.pt')

# Method 2: Tracing
# torch.jit.trace runs the model once on example_input and records the
# operations that were executed.
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)
traced_model.save('model_traced.pt')

# Load TorchScript model
# Loading needs only the saved .pt file, not the ImageClassifier class.
loaded_model = torch.jit.load('model_scripted.pt')
loaded_model.eval()

# Inference
# With the default num_classes=10 and a batch of 1 this prints
# torch.Size([1, 10]).
with torch.no_grad():
    output = loaded_model(example_input)
    print(f"Output shape: {output.shape}")

3. ONNX Export

python
import onnx
import onnxruntime as ort

def export_to_onnx(model, example_input, onnx_path, opset_version=11):
    """Export ``model`` to ONNX and validate the written file.

    Args:
        model: nn.Module to export. It is switched to eval mode in place.
        example_input: Tensor passed through the model during export. Its
            shape fixes every input dimension except dim 0, which is exported
            as the dynamic ``batch_size`` axis (the output's dim 0 likewise).
        onnx_path: Destination path of the .onnx file.
        opset_version: ONNX operator-set version to target. Defaults to 11,
            the value this function previously hard-coded; raise it when the
            model needs operators only available in newer opsets.

    Raises:
        Whatever ``torch.onnx.export`` raises for an unexportable model, and
        whatever ``onnx.checker.check_model`` raises for an invalid graph.
    """
    model.eval()

    torch.onnx.export(
        model,                          # Model
        example_input,                  # Example input
        onnx_path,                      # Output path
        export_params=True,             # Store the weights inside the file
        opset_version=opset_version,    # ONNX operator set version
        do_constant_folding=True,       # Constant folding optimization
        input_names=['input'],          # Input names
        output_names=['output'],        # Output names
        # Only the batch dimension may vary at inference time.
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )

    # Reload the file and structurally validate the exported graph.
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print(f"ONNX model saved to: {onnx_path}")

# Inference with ONNX Runtime
def onnx_inference(onnx_path, input_data, providers=None):
    """Run one forward pass of an ONNX model with ONNX Runtime.

    Args:
        onnx_path: Path to the .onnx file. A new session is created on every
            call, so reuse a session yourself when calling in a loop.
        input_data: torch.Tensor fed to the model's first input. It may live
            on any device and may require grad: it is detached and copied to
            CPU before conversion to NumPy.
        providers: Optional list of ONNX Runtime execution providers. ``None``
            (the default) keeps ONNX Runtime's own provider selection.

    Returns:
        NumPy array holding the model's first output.
    """
    session = ort.InferenceSession(onnx_path, providers=providers)

    # Get input/output info
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    # Tensor.numpy() refuses tensors that require grad or are not on the CPU.
    array = input_data.detach().cpu().numpy()

    # Inference
    result = session.run([output_name], {input_name: array})
    return result[0]

# Export and test
# The model is freshly constructed (no weights loaded), so only output shapes
# are meaningful here.
model = ImageClassifier()
example_input = torch.randn(1, 3, 224, 224)
export_to_onnx(model, example_input, 'model.onnx')

# ONNX inference test
# With the default num_classes=10 this prints (1, 10).
onnx_output = onnx_inference('model.onnx', example_input)
print(f"ONNX inference result shape: {onnx_output.shape}")

Web Service Deployment

1. Flask Deployment

python
from flask import Flask, request, jsonify
import torch
import torchvision.transforms as transforms
from PIL import Image
import io
import base64

app = Flask(__name__)

# Global variables
# Both are populated by load_model(); they stay None until it runs.
model = None
transform = None
# Maps output index -> label.
# NOTE(review): ImageClassifier defaults to num_classes=10 but only 5 labels
# are listed; confirm this matches the deployed model's output size.
class_names = ['cat', 'dog', 'bird', 'fish', 'horse']

def load_model():
    """Load the TorchScript model and build the preprocessing pipeline.

    Sets the module-level ``model`` and ``transform`` globals. The model file
    is taken from the ``MODEL_PATH`` environment variable (which the
    docker-compose and Kubernetes manifests set) and falls back to
    ``'model_scripted.pt'`` in the working directory.
    """
    global model, transform
    import os

    # Load model
    model_path = os.environ.get('MODEL_PATH', 'model_scripted.pt')
    model = torch.jit.load(model_path, map_location='cpu')
    model.eval()

    # Define preprocessing: resize to the 224x224 input size used when the
    # model was exported, then normalize with the standard ImageNet
    # per-channel mean/std.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])

def preprocess_image(image_bytes):
    """Decode encoded image bytes into a single-image batch tensor.

    Args:
        image_bytes: Raw contents of an image file in any format PIL can open.

    Returns:
        The output of the module-level ``transform`` with a leading batch
        dimension of size 1 added.
    """
    buffer = io.BytesIO(image_bytes)
    # Force 3 channels so grayscale/RGBA uploads match the model input.
    rgb = Image.open(buffer).convert('RGB')
    return transform(rgb).unsqueeze(0)

def _top_k_predictions(logits, k=5):
    """Turn one logit vector into up to ``k`` ``{'class', 'probability'}`` dicts.

    ``k`` is clamped to the number of logits so ``torch.topk`` cannot fail on
    models with fewer outputs. Indices with no entry in ``class_names`` are
    reported as ``'class_<index>'`` instead of raising IndexError.
    """
    probabilities = torch.nn.functional.softmax(logits, dim=0)
    top_prob, top_idx = torch.topk(probabilities, min(k, probabilities.numel()))
    results = []
    for prob, idx in zip(top_prob.tolist(), top_idx.tolist()):
        label = class_names[idx] if idx < len(class_names) else f'class_{idx}'
        results.append({'class': label, 'probability': prob})
    return results


@app.route('/predict', methods=['POST'])
def predict():
    """Classify an image uploaded as the multipart form field ``image``.

    Returns:
        200 with ``{'success': True, 'predictions': [...]}``, holding up to
        five entries with the highest probability first.
        400 if no image was uploaded or the upload cannot be decoded as an
        image.
        500 with the error message for any other failure.
    """
    try:
        # Get image data
        if 'image' not in request.files:
            return jsonify({'error': 'No image uploaded'}), 400

        # load_model() is otherwise only called from the __main__ block,
        # which a WSGI server never executes; load lazily on first use.
        if model is None or transform is None:
            load_model()

        image_bytes = request.files['image'].read()

        # Preprocess
        try:
            input_tensor = preprocess_image(image_bytes)
        except OSError as e:
            # PIL reports undecodable data with UnidentifiedImageError, an
            # OSError subclass: a client error, not a server failure.
            return jsonify({'error': f'Invalid image: {e}'}), 400

        # Inference
        with torch.no_grad():
            outputs = model(input_tensor)
        results = _top_k_predictions(outputs[0], k=5)

        return jsonify({
            'success': True,
            'predictions': results
        })

    except Exception as e:
        return jsonify({'error': str(e)}), 500

@app.route('/health', methods=['GET'])
def health_check():
    """Health check: the response body is always {"status": "healthy"}."""
    return jsonify(status='healthy')

if __name__ == '__main__':
    # Direct-run entry point only: load the model before the development
    # server starts. This block does not run when the module is imported
    # (e.g. by a WSGI server).
    load_model()
    # Listen on all interfaces, port 5000, with debug mode off.
    app.run(host='0.0.0.0', port=5000, debug=False)

2. FastAPI Deployment

python
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import torch
import torchvision.transforms as transforms
from PIL import Image
import io
from typing import List
import uvicorn

app = FastAPI(title="PyTorch Model API", version="1.0.0")

# Global variables
# Both are populated by the startup handler load_model(); None until then.
model = None
transform = None
# Maps output index -> label.
# NOTE(review): ImageClassifier defaults to num_classes=10 but only 5 labels
# are listed; confirm this matches the deployed model's output size.
class_names = ['cat', 'dog', 'bird', 'fish', 'horse']

@app.on_event("startup")
async def load_model():
    """Load the TorchScript model and preprocessing pipeline at startup.

    Sets the module-level ``model`` and ``transform`` globals. The model file
    is taken from the ``MODEL_PATH`` environment variable (which the
    docker-compose and Kubernetes manifests set) and falls back to
    ``'model_scripted.pt'`` in the working directory.

    NOTE: recent FastAPI releases deprecate ``on_event`` in favour of a
    ``lifespan`` handler; it still works but may warn.
    """
    global model, transform
    import os

    model_path = os.environ.get('MODEL_PATH', 'model_scripted.pt')
    model = torch.jit.load(model_path, map_location='cpu')
    model.eval()

    # Resize to the 224x224 export size, then normalize with the standard
    # ImageNet per-channel mean/std.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])

class PredictionResponse:
    """Plain container pairing a predicted class label with its probability.

    NOTE(review): not referenced by the visible endpoints, which return plain
    dicts; confirm whether it is still needed.
    """

    def __init__(self, class_name: str, probability: float):
        self.class_name, self.probability = class_name, probability

def _top_k_predictions(logits, k=5):
    """Turn one logit vector into up to ``k`` ``{"class_name", "probability"}`` dicts.

    ``k`` is clamped to the number of logits so ``torch.topk`` cannot fail on
    models with fewer outputs. Indices with no entry in ``class_names`` are
    reported as ``"class_<index>"`` instead of raising IndexError.
    """
    probabilities = torch.nn.functional.softmax(logits, dim=0)
    top_prob, top_idx = torch.topk(probabilities, min(k, probabilities.numel()))
    return [
        {
            "class_name": class_names[idx] if idx < len(class_names) else f"class_{idx}",
            "probability": float(prob),
        }
        for prob, idx in zip(top_prob.tolist(), top_idx.tolist())
    ]


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image file.

    Returns:
        ``{"predictions": [...]}`` with up to five entries, highest
        probability first.

    Raises:
        HTTPException(400): the upload is not declared as an image or cannot
            be decoded as one.
        HTTPException(500): any other failure, carrying the error message.
    """
    # Verify file type. content_type can be None when the client omits it.
    # This check sits outside the catch-all below so its 400 is not rewrapped
    # as a 500.
    if not (file.content_type or '').startswith('image/'):
        raise HTTPException(status_code=400, detail="File must be an image format")

    try:
        # Read and preprocess image
        image_bytes = await file.read()
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except OSError as e:
            # PIL's UnidentifiedImageError subclasses OSError: bad client data.
            raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e
        input_tensor = transform(image).unsqueeze(0)

        # Inference
        # NOTE(review): this runs synchronously on the event loop; for heavy
        # models consider a plain `def` endpoint or a thread pool.
        with torch.no_grad():
            outputs = model(input_tensor)

        return {"predictions": _top_k_predictions(outputs[0], k=5)}

    except HTTPException:
        # Already carries the intended status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.get("/health")
async def health_check():
    """Health check: always responds with {"status": "healthy"}."""
    return dict(status="healthy")

if __name__ == "__main__":
    # Direct-run entry point. Containers built from the Dockerfile start the
    # app through its uvicorn CMD instead, so this block is skipped there.
    uvicorn.run(app, host="0.0.0.0", port=8000)

Container Deployment

1. Docker Deployment

dockerfile
# Dockerfile
# NOTE(review): Python 3.9 reached end of life in October 2025; move to a newer
# slim base once requirements.txt has been verified against it.
FROM python:3.9-slim

# Send stdout/stderr straight to the container log instead of buffering it,
# and don't write .pyc files into the image.
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1

# Set working directory
WORKDIR /app

# Install system dependencies; --no-install-recommends skips optional extras
# to keep the image small. The apt lists are removed in the same layer.
RUN apt-get update && apt-get install -y --no-install-recommends \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libgomp1 \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first so the dependency layer stays cached until
# requirements.txt changes.
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Expose port
EXPOSE 8000

# Startup command: serve the FastAPI `app` object from main.py.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
yaml
# docker-compose.yml
# NOTE: Compose V2 ignores the top-level `version` key (and warns about it);
# it is kept for older docker-compose v1 installations.
version: '3.8'

services:
  pytorch-api:
    build: .
    ports:
      - "8000:8000"
    volumes:
      # The API only reads model artifacts, so mount them read-only.
      - ./models:/app/models:ro
    environment:
      # Model file location inside the container (within the mount above).
      - MODEL_PATH=/app/models/model_scripted.pt
    restart: unless-stopped

  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
    volumes:
      # nginx only needs to read its configuration.
      - ./nginx.conf:/etc/nginx/nginx.conf:ro
    depends_on:
      - pytorch-api
    restart: unless-stopped

2. Kubernetes Deployment

yaml
# deployment.yaml
# A Deployment running three replicas of the API container, followed (after
# the `---` separator) by a Service that routes traffic to them.
apiVersion: apps/v1
kind: Deployment
metadata:
  name: pytorch-model-deployment
spec:
  replicas: 3
  selector:
    matchLabels:
      app: pytorch-model
  template:
    metadata:
      labels:
        app: pytorch-model
    spec:
      containers:
      - name: pytorch-api
        # NOTE(review): with a `:latest` tag Kubernetes defaults
        # imagePullPolicy to Always, so the image must be pullable from a
        # registry the cluster can reach; confirm, or pin a version tag.
        image: pytorch-model:latest
        ports:
        - containerPort: 8000
        resources:
          requests:
            memory: "512Mi"
            cpu: "500m"
          limits:
            memory: "1Gi"
            cpu: "1000m"
        env:
        # NOTE(review): no volume is mounted in this pod spec, so the model
        # file must already exist at this path inside the image.
        - name: MODEL_PATH
          value: "/app/models/model_scripted.pt"
        # Liveness and readiness both poll the same /health endpoint on the
        # container port; liveness waits 30s before its first check,
        # readiness 5s.
        livenessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 30
          periodSeconds: 10
        readinessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 5
          periodSeconds: 5

---
# Exposes port 80 through a LoadBalancer and forwards it to port 8000 on the
# pods labelled app=pytorch-model.
apiVersion: v1
kind: Service
metadata:
  name: pytorch-model-service
spec:
  selector:
    app: pytorch-model
  ports:
  - protocol: TCP
    port: 80
    targetPort: 8000
  type: LoadBalancer

High-Performance Inference

1. Batch Inference Optimization

python
class BatchInferenceOptimizer:
    """Coalesce concurrent single-sample inference requests into batches.

    Each caller awaits ``predict(x)`` with one unbatched input. Queued inputs
    are run through ``model`` together as ``torch.stack(inputs)`` as soon as
    ``max_batch_size`` requests are waiting, or ``timeout`` seconds after the
    first request of a partial batch was queued, whichever comes first. Each
    caller receives its own row (``output[i]``) of the batched result.

    Intended for use from a single asyncio event loop; the model call runs
    synchronously on that loop.

    Args:
        model: Callable taking the stacked batch tensor and returning an
            output indexable along dim 0.
        max_batch_size: Queue length that triggers an immediate flush.
        timeout: Seconds a partial batch may wait before it is flushed.
    """

    def __init__(self, model, max_batch_size=32, timeout=0.1):
        self.model = model
        self.max_batch_size = max_batch_size
        self.timeout = timeout
        self.batch_queue = []
        # Unused; kept so code that touches this attribute keeps working.
        self.result_futures = []
        # Flush timer for the current partial batch. Keeping a strong
        # reference also stops the task being garbage-collected mid-sleep.
        self._timer_task = None

    async def predict(self, input_data):
        """Queue one input and return its row of the batched model output.

        Raises:
            Whatever exception the model (or ``torch.stack``) raised for the
            batch this input was part of.
        """
        import asyncio

        future = asyncio.get_running_loop().create_future()
        self.batch_queue.append((input_data, future))

        if len(self.batch_queue) >= self.max_batch_size:
            await self._process_batch()
        elif self._timer_task is None or self._timer_task.done():
            # One timer per partial batch rather than one per request, so the
            # wait is bounded by `timeout` from the first queued request.
            self._timer_task = asyncio.create_task(self._timeout_handler())

        return await future

    async def _process_batch(self):
        """Run the model on everything queued and resolve the waiting futures."""
        if not self.batch_queue:
            return

        # The batch the pending timer was guarding is being flushed now.
        timer, self._timer_task = self._timer_task, None
        if timer is not None:
            timer.cancel()

        inputs = [data for data, _ in self.batch_queue]
        futures = [fut for _, fut in self.batch_queue]
        self.batch_queue.clear()

        try:
            batch_input = torch.stack(inputs)
            with torch.no_grad():
                batch_output = self.model(batch_input)

            for i, fut in enumerate(futures):
                # A caller cancelled while waiting already has a finished
                # future; resolving it again would raise InvalidStateError
                # and wrongly fail the rest of the batch.
                if not fut.done():
                    fut.set_result(batch_output[i])

        except Exception as e:
            for fut in futures:
                if not fut.done():
                    fut.set_exception(e)

    async def _timeout_handler(self):
        """Flush the current partial batch after ``timeout`` seconds."""
        import asyncio

        await asyncio.sleep(self.timeout)
        # This timer has fired; clear it first so _process_batch does not try
        # to cancel the very task that is running it.
        self._timer_task = None
        if self.batch_queue:
            await self._process_batch()

2. TensorRT Optimization

python
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

class TensorRTInference:
    """Build a TensorRT engine from an ONNX file and run synchronous inference.

    The engine is built on construction (FP16 enabled, 1 GiB workspace), its
    serialized form is written to ``trt_path``, and one host/device buffer
    pair is allocated per binding.

    NOTE(review): this uses the TensorRT 8.x binding API (``build_engine``,
    ``get_binding_shape``, ``binding_is_input``, ``max_batch_size``), which
    TensorRT 10 removed; pin the TensorRT version or port to the tensor-name
    API.
    NOTE(review): buffers are sized from the engine's binding shapes, so an
    ONNX model exported with a dynamic batch axis (-1) presumably needs an
    optimization profile and explicit input shapes first — confirm.

    Args:
        onnx_path: Path to the ONNX model.
        trt_path: Where to write the serialized engine. Defaults to
            ``onnx_path`` with its extension replaced by ``.trt``.

    Raises:
        RuntimeError: if the ONNX file cannot be parsed or the build fails.
    """

    def __init__(self, onnx_path, trt_path=None):
        import os

        self.onnx_path = onnx_path
        # splitext rather than str.replace: replace() returns the path
        # unchanged when it lacks '.onnx' (so the engine would overwrite the
        # ONNX file), and it also rewrites '.onnx' inside directory names.
        self.trt_path = trt_path or os.path.splitext(onnx_path)[0] + '.trt'

        # Build TensorRT engine
        self.engine = self._build_engine()
        self.context = self.engine.create_execution_context()

        # Allocate GPU memory
        self._allocate_buffers()

    def _build_engine(self):
        """Parse the ONNX model, build an engine, and save it to ``trt_path``.

        Raises:
            RuntimeError: on ONNX parse errors or when the builder returns no
                engine (previously these surfaced later as AttributeError on
                ``None``).
        """
        logger = trt.Logger(trt.Logger.WARNING)
        builder = trt.Builder(logger)
        network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        parser = trt.OnnxParser(network, logger)

        # Parse ONNX model
        with open(self.onnx_path, 'rb') as model:
            if not parser.parse(model.read()):
                errors = [str(parser.get_error(i)) for i in range(parser.num_errors)]
                raise RuntimeError(
                    f"Failed to parse ONNX model {self.onnx_path!r}:\n" + "\n".join(errors)
                )

        # Configure builder
        config = builder.create_builder_config()
        workspace_bytes = 1 << 30  # 1GB
        if hasattr(config, 'set_memory_pool_limit'):
            # TensorRT >= 8.4 replaces max_workspace_size with memory pools.
            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
        else:
            config.max_workspace_size = workspace_bytes
        config.set_flag(trt.BuilderFlag.FP16)  # Enable FP16

        # Build engine
        engine = builder.build_engine(network, config)
        if engine is None:
            raise RuntimeError(f"TensorRT engine build failed for {self.onnx_path!r}")

        # Save engine
        with open(self.trt_path, 'wb') as f:
            f.write(engine.serialize())

        return engine

    def _allocate_buffers(self):
        """Allocate a page-locked host buffer and a device buffer per binding.

        Each buffer holds ``volume(binding_shape) * engine.max_batch_size``
        elements of the binding's dtype. Device pointers are collected in
        ``self.bindings`` in binding order for ``execute_v2``.
        """
        self.inputs = []
        self.outputs = []
        self.bindings = []

        for binding in self.engine:
            size = trt.volume(self.engine.get_binding_shape(binding)) * self.engine.max_batch_size
            dtype = trt.nptype(self.engine.get_binding_dtype(binding))

            # Allocate host and device memory
            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)

            self.bindings.append(int(device_mem))

            if self.engine.binding_is_input(binding):
                self.inputs.append({'host': host_mem, 'device': device_mem})
            else:
                self.outputs.append({'host': host_mem, 'device': device_mem})

    def infer(self, input_data):
        """Run one synchronous inference.

        Args:
            input_data: Array (e.g. NumPy) whose flattened contents fill the
                first input binding's host buffer.

        Returns:
            The first output's page-locked host buffer (1-D). It is overwritten
            by the next call, so copy it if the result must be kept.
        """
        # numpy was previously used here without being imported.
        import numpy as np

        # Copy input data to GPU
        np.copyto(self.inputs[0]['host'], input_data.ravel())
        cuda.memcpy_htod(self.inputs[0]['device'], self.inputs[0]['host'])

        # Execute inference
        self.context.execute_v2(bindings=self.bindings)

        # Copy output data to CPU
        cuda.memcpy_dtoh(self.outputs[0]['host'], self.outputs[0]['device'])

        return self.outputs[0]['host']

Edge Device Deployment

1. Mobile Deployment (PyTorch Mobile)

python
# Model optimization for mobile
def optimize_for_mobile(model, example_input, output_path="model_mobile.ptl"):
    """Trace, mobile-optimize and save a model for the PyTorch Lite interpreter.

    Args:
        model: nn.Module to export. It is switched to eval mode in place.
        example_input: Input used for tracing; the saved graph records the
            operations executed for this input.
        output_path: Where to write the ``.ptl`` file. Defaults to
            ``"model_mobile.ptl"``, the path previously hard-coded.

    Returns:
        The optimized TorchScript module.
    """
    # Alias the import: binding it as `optimize_for_mobile` would shadow this
    # function's own name inside its body and read like recursion.
    from torch.utils.mobile_optimizer import optimize_for_mobile as _torch_optimize_for_mobile

    model.eval()

    # Trace model
    traced_model = torch.jit.trace(model, example_input)

    # Mobile optimization
    optimized_model = _torch_optimize_for_mobile(traced_model)

    # Save optimized model
    optimized_model._save_for_lite_interpreter(output_path)

    return optimized_model

# Usage example
# Exports a freshly constructed (untrained) ImageClassifier traced with a
# 1x3x224x224 example input; also writes model_mobile.ptl.
model = ImageClassifier()
example_input = torch.randn(1, 3, 224, 224)
mobile_model = optimize_for_mobile(model, example_input)

2. Quantization Deployment

python
def quantize_model_for_deployment(model, calibration_loader, backend='fbgemm'):
    """Post-training static (eager-mode) int8 quantization.

    Args:
        model: Float nn.Module. It is not modified; a deep copy is quantized.
        calibration_loader: Iterable of ``(inputs, target)`` pairs. Only the
            inputs are run through the model, to let observers record
            activation ranges.
        backend: Backend whose default qconfig is used (``'fbgemm'`` for x86
            servers, ``'qnnpack'`` for ARM/mobile). Defaults to the
            previously hard-coded ``'fbgemm'``.

    Returns:
        The converted quantized model. If ``model`` contains no QuantStub, it
        is wrapped in ``QuantWrapper`` first, so the result accepts and returns
        float tensors.
    """
    import copy

    # Work on a private copy so the caller's float model gets no qconfig or
    # observers attached.
    float_model = copy.deepcopy(model)

    # Eager-mode static quantization quantizes activations only where a
    # QuantStub sits; without one the converted layers would receive float
    # tensors and fail at run time. Wrap models that don't place their own.
    has_stubs = any(isinstance(m, torch.quantization.QuantStub) for m in float_model.modules())
    if not has_stubs:
        float_model = torch.quantization.QuantWrapper(float_model)
    float_model.eval()

    # Set quantization configuration
    float_model.qconfig = torch.quantization.get_default_qconfig(backend)

    # Prepare quantization (in place is safe: float_model is our own copy)
    model_prepared = torch.quantization.prepare(float_model, inplace=True)

    # Calibrate
    with torch.no_grad():
        for data, _ in calibration_loader:
            model_prepared(data)

    # Convert to quantized model
    model_quantized = torch.quantization.convert(model_prepared, inplace=False)

    return model_quantized

# Dynamic quantization (simpler approach)
def dynamic_quantize_model(model, qconfig_spec=None, dtype=torch.qint8):
    """Dynamically quantize selected layer types of ``model``.

    Weights of the selected layer types are quantized ahead of time;
    activations are quantized on the fly at inference, so no calibration
    data is needed. ``model`` itself is left unchanged (a converted copy is
    returned).

    Args:
        model: Float nn.Module.
        qconfig_spec: Set of module types to quantize. Defaults to
            ``{torch.nn.Linear}``, the set previously hard-coded. Pass e.g.
            ``{torch.nn.Linear, torch.nn.LSTM}`` for recurrent models.
        dtype: Weight dtype, ``torch.qint8`` by default.

    Returns:
        The dynamically quantized model.
    """
    # None sentinel avoids a shared mutable default argument.
    if qconfig_spec is None:
        qconfig_spec = {torch.nn.Linear}
    quantized_model = torch.quantization.quantize_dynamic(
        model, qconfig_spec, dtype=dtype
    )
    return quantized_model

Monitoring and Logging

1. Performance Monitoring

python
import time
import psutil
import logging
from functools import wraps

class PerformanceMonitor:
    """Accumulate inference-latency and host-resource metrics.

    All values live in the ``metrics`` dict:
      - request_count: number of successful calls to decorated functions
      - total_inference_time / average_inference_time: seconds
      - memory_usage / cpu_usage: percentages from psutil, refreshed by
        ``update_system_metrics()`` / ``get_metrics()``
    """

    def __init__(self):
        self.metrics = {
            'request_count': 0,
            'total_inference_time': 0,
            'average_inference_time': 0,
            'memory_usage': 0,
            'cpu_usage': 0
        }

    def log_inference_time(self, func):
        """Decorator: time each call of ``func`` and fold it into the metrics.

        Calls that raise propagate unchanged and are not counted.
        """
        @wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic and high-resolution; time.time() follows
            # the wall clock and can jump (NTP, DST), giving wrong or negative
            # durations.
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            inference_time = time.perf_counter() - start_time

            self.metrics['request_count'] += 1
            self.metrics['total_inference_time'] += inference_time
            self.metrics['average_inference_time'] = (
                self.metrics['total_inference_time'] / self.metrics['request_count']
            )

            # Lazy %-formatting: the string is only built if INFO is enabled.
            logging.info("Inference time: %.4fs", inference_time)
            return result
        return wrapper

    def update_system_metrics(self):
        """Refresh memory and CPU usage percentages from psutil."""
        self.metrics['memory_usage'] = psutil.virtual_memory().percent
        # NOTE: psutil documents that the first non-blocking cpu_percent() call
        # returns a meaningless 0.0; later calls measure since the previous one.
        self.metrics['cpu_usage'] = psutil.cpu_percent()

    def get_metrics(self):
        """Refresh system metrics and return the (live) metrics dict."""
        self.update_system_metrics()
        return self.metrics

# Usage example
monitor = PerformanceMonitor()

# Every successful call to model_inference is timed and counted by `monitor`.
# `model` is not defined in this snippet; it is assumed to be the loaded model.
@monitor.log_inference_time
def model_inference(input_data):
    with torch.no_grad():
        return model(input_data)

2. Logging Configuration

python
import logging
import json
from datetime import datetime

class ModelLogger:
    """Log prediction and error events as JSON payloads.

    Every instance writes through the shared ``'ModelService'`` logger, at
    INFO level, to both ``log_file`` and the console.

    Args:
        log_file: Path of the log file to append to.
    """

    def __init__(self, log_file='model_service.log'):
        import os

        self.logger = logging.getLogger('ModelService')
        self.logger.setLevel(logging.INFO)

        # getLogger returns the same object for the same name, so without
        # these checks every new ModelLogger would attach another pair of
        # handlers and each record would be written once per instance.
        log_path = os.path.abspath(log_file)
        has_file_handler = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == log_path
            for h in self.logger.handlers
        )
        # Exact type check: FileHandler is itself a StreamHandler subclass.
        has_console_handler = any(
            type(h) is logging.StreamHandler for h in self.logger.handlers
        )

        # Formatter
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )

        if not has_file_handler:
            # Explicit UTF-8: payloads are dumped with ensure_ascii=False and
            # may hold text the platform default encoding cannot represent.
            file_handler = logging.FileHandler(log_file, encoding='utf-8')
            file_handler.setLevel(logging.INFO)
            file_handler.setFormatter(formatter)
            self.logger.addHandler(file_handler)

        if not has_console_handler:
            console_handler = logging.StreamHandler()
            console_handler.setLevel(logging.INFO)
            console_handler.setFormatter(formatter)
            self.logger.addHandler(console_handler)

    def log_prediction(self, input_info, prediction, confidence, inference_time):
        """Log one completed prediction at INFO level as a JSON payload.

        ``timestamp`` is naive local time in ISO-8601. All arguments must be
        JSON-serializable (e.g. convert tensors first), otherwise json.dumps
        raises TypeError.
        """
        log_data = {
            'timestamp': datetime.now().isoformat(),
            'input_info': input_info,
            'prediction': prediction,
            'confidence': confidence,
            'inference_time': inference_time
        }

        self.logger.info(f"Prediction completed: {json.dumps(log_data, ensure_ascii=False)}")

    def log_error(self, error_msg, input_info=None):
        """Log a failed prediction at ERROR level as a JSON payload."""
        log_data = {
            'timestamp': datetime.now().isoformat(),
            'error': error_msg,
            'input_info': input_info
        }

        self.logger.error(f"Prediction error: {json.dumps(log_data, ensure_ascii=False)}")

Summary

PyTorch model deployment covers the complete process from development to production:

  1. Model Serialization: TorchScript, ONNX, and other format conversions
  2. Web Services: API services using Flask, FastAPI, etc.
  3. Containerization: Docker, Kubernetes, and other container deployments
  4. Performance Optimization: Batch inference, TensorRT, and other acceleration techniques
  5. Edge Deployment: Mobile, quantization, and other lightweight solutions
  6. Monitoring and Logging: Performance monitoring and a structured logging system

Mastering these deployment techniques will help you successfully put PyTorch models into production environments!

Content is for learning and research only.