
PyTorch Text Classification Project

Project Overview

This chapter will demonstrate how to use PyTorch for natural language processing tasks through a complete text classification project. We will build a sentiment analysis system capable of determining the sentiment tendency of text (positive, negative, neutral).

Project Structure

text_classification/
├── data/                   # Data directory
│   ├── raw/               # Raw data
│   ├── processed/         # Processed data
│   └── vocab/             # Vocabulary
├── models/                # Model definitions
│   ├── __init__.py
│   ├── lstm_classifier.py
│   ├── transformer_classifier.py
│   └── cnn_classifier.py
├── utils/                 # Utility functions
│   ├── __init__.py
│   ├── data_loader.py
│   ├── text_processor.py
│   └── metrics.py
├── configs/               # Configuration files
├── train.py              # Training script
├── evaluate.py           # Evaluation script
└── inference.py          # Inference script

Data Preprocessing

1. Text Preprocessor

python
import re
import string
import torch
from collections import Counter, defaultdict
from typing import List, Dict, Tuple
import jieba  # Chinese word segmentation
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

class TextPreprocessor:
    """Clean, tokenize, and index raw text for classification models.

    Supports English (NLTK tokenization + stopword removal) and Chinese
    (jieba segmentation). Build a vocabulary with `build_vocab`, then use
    `text_to_sequence` / `sequence_to_text` to move between text and ids.
    """

    def __init__(self, language='en', max_vocab_size=50000, min_freq=2):
        self.language = language
        self.max_vocab_size = max_vocab_size
        self.min_freq = min_freq

        # Special tokens; these always occupy indices 0-3 of the vocabulary.
        self.PAD_TOKEN = '<PAD>'
        self.UNK_TOKEN = '<UNK>'
        self.SOS_TOKEN = '<SOS>'
        self.EOS_TOKEN = '<EOS>'

        # Vocabulary state, filled in by build_vocab() or load_vocab().
        self.vocab = {}
        self.idx2word = {}
        self.word_freq = Counter()

        # Stopword list (English only; other languages skip filtering).
        if language == 'en':
            try:
                self.stop_words = set(stopwords.words('english'))
            except LookupError:
                # Narrowed from a bare `except`: only a missing NLTK corpus
                # should be silenced here, not arbitrary programming errors.
                self.stop_words = set()
        else:
            self.stop_words = set()

    def clean_text(self, text: str) -> str:
        """Lowercase and strip HTML tags, URLs, emails, and extra whitespace."""
        text = text.lower()

        # Remove HTML tags
        text = re.sub(r'<[^>]+>', '', text)

        # Remove URLs
        text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)

        # Remove email addresses
        text = re.sub(r'\S+@\S+', '', text)

        # Remove numbers (optional)
        # text = re.sub(r'\d+', '', text)

        # Collapse runs of whitespace into single spaces
        text = re.sub(r'\s+', ' ', text).strip()

        return text

    def tokenize(self, text: str) -> List[str]:
        """Split text into tokens, dropping punctuation and stopwords."""
        if self.language == 'zh':
            # Chinese word segmentation
            tokens = list(jieba.cut(text))
        else:
            # English tokenization (requires the NLTK 'punkt' tokenizer data)
            tokens = word_tokenize(text)

        return [
            token for token in tokens
            if token not in string.punctuation and token not in self.stop_words
        ]

    def build_vocab(self, texts: List[str]):
        """Build the token->index mapping from a corpus of raw texts."""
        print("Building vocabulary...")

        # Count token frequencies over the cleaned corpus.
        for text in texts:
            self.word_freq.update(self.tokenize(self.clean_text(text)))

        vocab_list = [self.PAD_TOKEN, self.UNK_TOKEN, self.SOS_TOKEN, self.EOS_TOKEN]

        # Scan in descending frequency order. Because the list is sorted we
        # can stop as soon as the frequency drops below min_freq or the size
        # budget is exhausted — the previous version scanned every remaining
        # word needlessly; the resulting vocabulary is identical.
        for word, freq in sorted(self.word_freq.items(), key=lambda x: x[1], reverse=True):
            if freq < self.min_freq or len(vocab_list) >= self.max_vocab_size:
                break
            vocab_list.append(word)

        self.vocab = {word: idx for idx, word in enumerate(vocab_list)}
        self.idx2word = {idx: word for word, idx in self.vocab.items()}

        print(f"Vocabulary size: {len(self.vocab)}")
        print(f"Total word frequency: {sum(self.word_freq.values())}")

    def text_to_sequence(self, text: str, max_length: int = None) -> List[int]:
        """Encode text to token ids, truncated/padded to max_length if given."""
        tokens = self.tokenize(self.clean_text(text))

        # Out-of-vocabulary tokens map to the <UNK> id.
        sequence = [
            self.vocab.get(token, self.vocab[self.UNK_TOKEN])
            for token in tokens
        ]

        if max_length:
            if len(sequence) > max_length:
                sequence = sequence[:max_length]
            else:
                sequence.extend([self.vocab[self.PAD_TOKEN]] * (max_length - len(sequence)))

        return sequence

    def sequence_to_text(self, sequence: List[int]) -> str:
        """Decode token ids back to a space-joined string, dropping padding."""
        tokens = [
            self.idx2word.get(idx, self.UNK_TOKEN)
            for idx in sequence
            if idx != self.vocab[self.PAD_TOKEN]
        ]
        return ' '.join(tokens)

    def save_vocab(self, filepath: str):
        """Pickle the vocabulary and preprocessor configuration to filepath."""
        import pickle
        vocab_data = {
            'vocab': self.vocab,
            'idx2word': self.idx2word,
            'word_freq': self.word_freq,
            'config': {
                'language': self.language,
                'max_vocab_size': self.max_vocab_size,
                'min_freq': self.min_freq
            }
        }
        with open(filepath, 'wb') as f:
            pickle.dump(vocab_data, f)

    def load_vocab(self, filepath: str):
        """Restore vocabulary and configuration written by save_vocab().

        NOTE: pickle.load is unsafe on untrusted files — only load vocab
        files produced by this class.
        """
        import pickle
        with open(filepath, 'rb') as f:
            vocab_data = pickle.load(f)

        self.vocab = vocab_data['vocab']
        self.idx2word = vocab_data['idx2word']
        self.word_freq = vocab_data['word_freq']
        config = vocab_data['config']
        self.language = config['language']
        self.max_vocab_size = config['max_vocab_size']
        self.min_freq = config['min_freq']

# Usage example: build a vocabulary from a tiny corpus and round-trip a text.
# NOTE(review): requires the NLTK 'punkt' tokenizer data for English —
# run nltk.download('punkt') first if it is missing.
preprocessor = TextPreprocessor(language='en')

# Example texts (toy corpus; real projects load these from disk)
texts = [
    "I love this movie! It's absolutely fantastic.",
    "This film is terrible. I hate it.",
    "The movie is okay, not great but not bad either."
]

# Build vocabulary from the corpus
preprocessor.build_vocab(texts)

# Encode the first text to a fixed-length id sequence (padded to 20)
sequence = preprocessor.text_to_sequence(texts[0], max_length=20)
print(f"Original text: {texts[0]}")
print(f"Sequence: {sequence}")
print(f"Recovered: {preprocessor.sequence_to_text(sequence)}")

2. Dataset Class

python
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch

class TextClassificationDataset(Dataset):
    """Wraps raw texts and labels, yielding tensor-encoded samples on access."""

    def __init__(self, texts, labels, preprocessor, max_length=128):
        self.texts = texts
        self.labels = labels
        self.preprocessor = preprocessor
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        raw = self.texts[idx]
        # Encoding happens lazily, per sample, so memory stays low.
        encoded = self.preprocessor.text_to_sequence(raw, self.max_length)
        return {
            'input_ids': torch.tensor(encoded, dtype=torch.long),
            'label': torch.tensor(self.labels[idx], dtype=torch.long),
            'text': raw,
        }

def create_data_loaders(train_texts, train_labels, val_texts, val_labels, 
                       preprocessor, batch_size=32, max_length=128):
    """Build a shuffled training loader and a sequential validation loader."""

    def _make_loader(texts, labels, shuffle):
        # Both loaders share preprocessing, batch size, and worker settings;
        # only shuffling differs between train and validation.
        dataset = TextClassificationDataset(texts, labels, preprocessor, max_length)
        return DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle,
            num_workers=4, pin_memory=True
        )

    train_loader = _make_loader(train_texts, train_labels, True)
    val_loader = _make_loader(val_texts, val_labels, False)

    return train_loader, val_loader

# Example data loading
def load_imdb_data():
    """Load IMDB dataset example"""
    # Toy samples stand in for the real dataset in this demo.
    base_texts = [
        "I love this movie! It's absolutely fantastic.",
        "This film is terrible. I hate it.",
        "The movie is okay, not great but not bad either.",
        "Amazing cinematography and great acting!",
        "Boring and predictable plot."
    ]
    base_labels = [1, 0, 2, 1, 0]  # 0: negative, 1: positive, 2: neutral

    # Replicate the five samples to simulate a larger corpus.
    train_texts = base_texts * 1000
    train_labels = base_labels * 1000

    # Validation reuses a slice of training data (acceptable for a demo only).
    val_texts = train_texts[:500]
    val_labels = train_labels[:500]

    return train_texts, train_labels, val_texts, val_labels

Model Architecture

1. LSTM Classifier

python
import torch.nn as nn
import torch.nn.functional as F

class LSTMClassifier(nn.Module):
    """(Bi)LSTM text classifier with masked attention pooling.

    With an attention mask, hidden states are pooled by a learned attention
    weighting over non-padding positions; without one, plain mean pooling
    over time is used.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes, 
                 num_layers=2, dropout=0.3, bidirectional=True):
        super().__init__()

        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        # Index 0 is the padding token; its embedding stays zero.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

        # Inter-layer dropout is only meaningful for stacked LSTMs.
        recurrent_dropout = dropout if num_layers > 1 else 0
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim, num_layers,
            batch_first=True, dropout=recurrent_dropout,
            bidirectional=bidirectional,
        )

        out_dim = hidden_dim * (2 if bidirectional else 1)
        # Scores each timestep for attention pooling.
        self.attention = nn.Linear(out_dim, 1)

        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(out_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, input_ids, attention_mask=None):
        """Return class logits of shape (batch, num_classes)."""
        states, _ = self.lstm(self.embedding(input_ids))

        if attention_mask is None:
            # No mask supplied: fall back to simple mean pooling over time.
            pooled = states.mean(dim=1)
        else:
            scores = self.attention(states).squeeze(-1)
            # Large negative fill drives padded positions to ~0 after softmax.
            scores = scores.masked_fill(attention_mask == 0, -1e9)
            weights = F.softmax(scores, dim=1)
            pooled = (weights.unsqueeze(-1) * states).sum(dim=1)

        return self.classifier(pooled)

# Create model
def create_lstm_model(vocab_size, num_classes):
    """Factory for the default 2-layer BiLSTM classifier configuration."""
    config = dict(
        vocab_size=vocab_size,
        embed_dim=128,
        hidden_dim=256,
        num_classes=num_classes,
        num_layers=2,
        dropout=0.3,
        bidirectional=True,
    )
    return LSTMClassifier(**config)

2. CNN Classifier

python
class CNNClassifier(nn.Module):
    """TextCNN: parallel 1-D convolutions plus max-over-time pooling."""

    def __init__(self, vocab_size, embed_dim, num_filters, filter_sizes, 
                 num_classes, dropout=0.3):
        super().__init__()

        # Index 0 is the padding token; its embedding stays zero.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

        # One convolution branch per kernel width (each captures n-grams of
        # that width).
        self.convs = nn.ModuleList(
            nn.Conv1d(embed_dim, num_filters, kernel_size=width)
            for width in filter_sizes
        )

        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(len(filter_sizes) * num_filters, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

    def forward(self, input_ids):
        """Return class logits of shape (batch, num_classes)."""
        # (batch, seq, embed) -> (batch, embed, seq) as Conv1d expects.
        features = self.embedding(input_ids).transpose(1, 2)

        # Each branch: convolve, activate, then take the max over time.
        branch_outputs = [
            F.relu(conv(features)).max(dim=2).values
            for conv in self.convs
        ]

        # Concatenated feature vector: (batch, len(filter_sizes) * num_filters).
        return self.classifier(torch.cat(branch_outputs, dim=1))

def create_cnn_model(vocab_size, num_classes):
    """Factory for the default TextCNN configuration (3/4/5-gram filters)."""
    config = dict(
        vocab_size=vocab_size,
        embed_dim=128,
        num_filters=100,
        filter_sizes=[3, 4, 5],
        num_classes=num_classes,
        dropout=0.3,
    )
    return CNNClassifier(**config)

3. Transformer Classifier

python
class TransformerClassifier(nn.Module):
    """Transformer-encoder text classifier with learned positional encoding.

    Sequences are embedded, position-encoded, passed through a stack of
    encoder layers (padding masked out), then mean-pooled over non-padding
    positions and classified.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, 
                 num_classes, max_length=512, dropout=0.1):
        super(TransformerClassifier, self).__init__()

        self.embed_dim = embed_dim
        self.max_length = max_length

        # Word embedding (index 0 = padding) and learned positional encoding.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.pos_encoding = nn.Parameter(torch.randn(max_length, embed_dim))

        # Transformer encoder stack.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

        # Classification head over the pooled representation.
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, num_classes)
        )

    def forward(self, input_ids, attention_mask=None):
        """Return class logits of shape (batch, num_classes).

        Args:
            input_ids: (batch, seq_len) token ids; 0 is treated as padding.
            attention_mask: optional (batch, seq_len) bool mask, True for
                real tokens. Derived from `input_ids != 0` when omitted.
        """
        seq_len = input_ids.size(1)

        # Word embedding + positional encoding. Fixed: out-of-place add
        # instead of `embedded += ...`, which mutated the embedding output
        # in place and can break autograd.
        embedded = self.embedding(input_ids)
        embedded = embedded + self.pos_encoding[:seq_len, :].unsqueeze(0)

        # Derive the padding mask from the ids when not supplied.
        if attention_mask is None:
            attention_mask = (input_ids != 0)

        # src_key_padding_mask expects True at positions to IGNORE.
        transformer_out = self.transformer(
            embedded, 
            src_key_padding_mask=~attention_mask
        )

        # Mean pooling over non-padding positions only. Fixed: clamp the
        # denominator so an all-padding row cannot cause division by zero.
        mask_expanded = attention_mask.unsqueeze(-1).float()
        sum_embeddings = torch.sum(transformer_out * mask_expanded, dim=1)
        sum_mask = torch.sum(mask_expanded, dim=1).clamp(min=1e-9)
        pooled = sum_embeddings / sum_mask

        return self.classifier(pooled)

def create_transformer_model(vocab_size, num_classes):
    """Factory for the default 6-layer, 8-head Transformer classifier."""
    config = dict(
        vocab_size=vocab_size,
        embed_dim=256,
        num_heads=8,
        num_layers=6,
        num_classes=num_classes,
        max_length=512,
        dropout=0.1,
    )
    return TransformerClassifier(**config)

Training Framework

1. Trainer

python
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import numpy as np

class TextClassificationTrainer:
    """Training driver for the classifiers above.

    Owns the loss, optimizer, and LR scheduler; runs the train/validate
    loop; tracks per-epoch metrics; and checkpoints the best model (by
    validation accuracy) to 'best_model.pth'.
    """

    def __init__(self, model, train_loader, val_loader, num_classes, device):
        # Move the model to the target device once, up front.
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.num_classes = num_classes
        self.device = device
        
        # Loss function and optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
        # mode='max' because the scheduler is stepped with validation
        # accuracy: halve the LR after 3 epochs without improvement.
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', patience=3, factor=0.5)
        
        # Per-epoch training history (appended to by train()).
        self.train_losses = []
        self.train_accs = []
        self.val_losses = []
        self.val_accs = []
        
        # Best validation accuracy seen so far; drives checkpointing.
        self.best_val_acc = 0.0
    
    def train_epoch(self):
        """Run one optimization pass over train_loader.

        Returns:
            (avg_loss, accuracy) averaged/computed over the whole epoch.
        """
        self.model.train()
        total_loss = 0
        all_preds = []
        all_labels = []
        
        for batch in self.train_loader:
            input_ids = batch['input_ids'].to(self.device)
            labels = batch['label'].to(self.device)
            
            # Token id 0 is padding, so the mask marks real tokens.
            attention_mask = (input_ids != 0)
            
            self.optimizer.zero_grad()
            
            # Only the transformer takes an explicit attention mask;
            # the LSTM/CNN models infer pooling from the ids alone.
            if isinstance(self.model, TransformerClassifier):
                logits = self.model(input_ids, attention_mask)
            else:
                logits = self.model(input_ids)
            
            loss = self.criterion(logits, labels)
            
            # Backward pass with gradient-norm clipping for stability.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            
            # Accumulate statistics for epoch-level metrics.
            total_loss += loss.item()
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
        
        avg_loss = total_loss / len(self.train_loader)
        accuracy = accuracy_score(all_labels, all_preds)
        
        return avg_loss, accuracy
    
    def validate_epoch(self):
        """Evaluate on val_loader without gradient tracking.

        Returns:
            (avg_loss, accuracy, all_preds, all_labels) — the prediction and
            label lists allow callers to build reports/confusion matrices.
        """
        self.model.eval()
        total_loss = 0
        all_preds = []
        all_labels = []
        
        with torch.no_grad():
            for batch in self.val_loader:
                input_ids = batch['input_ids'].to(self.device)
                labels = batch['label'].to(self.device)
                
                attention_mask = (input_ids != 0)
                
                if isinstance(self.model, TransformerClassifier):
                    logits = self.model(input_ids, attention_mask)
                else:
                    logits = self.model(input_ids)
                
                loss = self.criterion(logits, labels)
                
                total_loss += loss.item()
                preds = torch.argmax(logits, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        
        avg_loss = total_loss / len(self.val_loader)
        accuracy = accuracy_score(all_labels, all_preds)
        
        return avg_loss, accuracy, all_preds, all_labels
    
    def train(self, num_epochs):
        """Run the full train/validate loop for num_epochs.

        Side effects: updates the history lists, steps the LR scheduler on
        validation accuracy, and saves the best checkpoint to
        'best_model.pth' in the working directory.

        Returns:
            (train_losses, train_accs, val_losses, val_accs) histories.
        """
        print(f"Starting training, {num_epochs} epochs")
        
        for epoch in range(num_epochs):
            # Train
            train_loss, train_acc = self.train_epoch()
            
            # Validate
            val_loss, val_acc, val_preds, val_labels = self.validate_epoch()
            
            # Scheduler watches validation accuracy (mode='max').
            self.scheduler.step(val_acc)
            
            # Record history
            self.train_losses.append(train_loss)
            self.train_accs.append(train_acc)
            self.val_losses.append(val_loss)
            self.val_accs.append(val_acc)
            
            # Print results
            print(f'Epoch {epoch+1}/{num_epochs}:')
            print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
            print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
            print(f'  LR: {self.optimizer.param_groups[0]["lr"]:.6f}')
            
            # Checkpoint whenever validation accuracy improves.
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                torch.save(self.model.state_dict(), 'best_model.pth')
                print(f'  ✓ New best model! Validation accuracy: {val_acc:.4f}')
            
            print('-' * 50)
        
        print(f'Training complete! Best validation accuracy: {self.best_val_acc:.4f}')
        
        return self.train_losses, self.train_accs, self.val_losses, self.val_accs

Model Evaluation

1. Detailed Evaluation

python
def evaluate_model(model, test_loader, device, class_names):
    """Run the model over test_loader and report accuracy, a per-class
    classification report, and the confusion matrix."""
    model.eval()
    preds, labels, probabilities = [], [], []

    with torch.no_grad():
        for batch in test_loader:
            ids = batch['input_ids'].to(device)
            targets = batch['label'].to(device)
            mask = (ids != 0)

            # Transformer models require the explicit padding mask.
            if isinstance(model, TransformerClassifier):
                logits = model(ids, mask)
            else:
                logits = model(ids)

            batch_probs = F.softmax(logits, dim=1)
            preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            labels.extend(targets.cpu().numpy())
            probabilities.extend(batch_probs.cpu().numpy())

    # Aggregate metrics over the full test set.
    accuracy = accuracy_score(labels, preds)
    report = classification_report(labels, preds, target_names=class_names)
    cm = confusion_matrix(labels, preds)

    print(f"Test accuracy: {accuracy:.4f}")
    print("\nClassification report:")
    print(report)

    return accuracy, report, cm, probabilities

2. Error Analysis

python
def analyze_errors(model, test_loader, preprocessor, device, class_names, num_examples=10):
    """Collect and print up to num_examples misclassified test samples."""
    model.eval()
    errors = []

    with torch.no_grad():
        for batch in test_loader:
            ids = batch['input_ids'].to(device)
            targets = batch['label'].to(device)
            raw_texts = batch['text']
            mask = (ids != 0)

            if isinstance(model, TransformerClassifier):
                logits = model(ids, mask)
            else:
                logits = model(ids)

            probs = F.softmax(logits, dim=1)
            predictions = torch.argmax(logits, dim=1)

            # Record every mismatch until the example budget is reached.
            mismatch = predictions != targets
            if mismatch.any():
                for pos in torch.where(mismatch)[0]:
                    errors.append({
                        'text': raw_texts[pos],
                        'true_label': class_names[targets[pos].item()],
                        'pred_label': class_names[predictions[pos].item()],
                        'confidence': probs[pos].max().item(),
                        'all_probs': probs[pos].cpu().numpy()
                    })
                    if len(errors) >= num_examples:
                        break

            if len(errors) >= num_examples:
                break

    # Human-readable dump of the collected errors.
    print("Incorrect Prediction Analysis:")
    print("=" * 80)

    for i, error in enumerate(errors):
        print(f"\nExample {i+1}:")
        print(f"Text: {error['text'][:200]}...")
        print(f"True label: {error['true_label']}")
        print(f"Predicted label: {error['pred_label']}")
        print(f"Confidence: {error['confidence']:.4f}")

        # Per-class probability breakdown for this example.
        for j, prob in enumerate(error['all_probs']):
            print(f"  {class_names[j]}: {prob:.4f}")

    return errors

Inference and Application

1. Single Text Inference

python
def predict_single_text(model, text, preprocessor, device, class_names, max_length=128):
    """Classify one raw text; return label, confidence, and the full
    per-class probability distribution."""
    model.eval()

    # Encode to a single-row batch.
    encoded = preprocessor.text_to_sequence(text, max_length)
    ids = torch.tensor([encoded], dtype=torch.long).to(device)
    mask = (ids != 0)

    with torch.no_grad():
        if isinstance(model, TransformerClassifier):
            logits = model(ids, mask)
        else:
            logits = model(ids)

        probs = F.softmax(logits, dim=1)
        winner = torch.argmax(logits, dim=1).item()
        confidence = probs[0][winner].item()

    distribution = probs[0].cpu().numpy()

    return {
        'text': text,
        'predicted_class': class_names[winner],
        'confidence': confidence,
        'all_probabilities': {
            class_names[i]: float(p) for i, p in enumerate(distribution)
        },
    }

# Usage example for single-text inference.
# NOTE(review): assumes `model`, `preprocessor`, and `device` are already
# defined (e.g. by the training script) — this snippet is not self-contained.
text = "I absolutely love this movie! It's fantastic!"
result = predict_single_text(model, text, preprocessor, device, ['negative', 'positive', 'neutral'])
print(f"Prediction result: {result}")

2. Batch Inference

python
def batch_predict(model, texts, preprocessor, device, class_names, batch_size=32, max_length=128):
    """Classify a list of texts in fixed-size batches; return one result
    dict (text, predicted class, confidence) per input text, in order."""
    model.eval()
    results = []

    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]

        # Encode the whole chunk into one padded tensor.
        encoded = [preprocessor.text_to_sequence(t, max_length) for t in chunk]
        ids = torch.tensor(encoded, dtype=torch.long).to(device)
        mask = (ids != 0)

        with torch.no_grad():
            if isinstance(model, TransformerClassifier):
                logits = model(ids, mask)
            else:
                logits = model(ids)

            probs = F.softmax(logits, dim=1)
            winners = torch.argmax(logits, dim=1)

        # Unpack per-sample results in input order.
        for offset, t in enumerate(chunk):
            label_idx = winners[offset].item()
            results.append({
                'text': t,
                'predicted_class': class_names[label_idx],
                'confidence': probs[offset][label_idx].item(),
            })

    return results

Complete Training Script

python
def main():
    """End-to-end demo: load data, build vocab, train, evaluate, save."""
    # Prefer GPU when available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Data and preprocessing.
    train_texts, train_labels, val_texts, val_labels = load_imdb_data()
    preprocessor = TextPreprocessor(language='en')
    preprocessor.build_vocab(train_texts)

    train_loader, val_loader = create_data_loaders(
        train_texts, train_labels, val_texts, val_labels,
        preprocessor, batch_size=32, max_length=128
    )

    vocab_size = len(preprocessor.vocab)
    num_classes = 3  # negative, positive, neutral

    # Swap in create_cnn_model / create_transformer_model to try the
    # other architectures.
    model = create_lstm_model(vocab_size, num_classes)

    print(f"Model parameter count: {sum(p.numel() for p in model.parameters()):,}")

    # Train.
    trainer = TextClassificationTrainer(model, train_loader, val_loader, num_classes, device)
    train_losses, train_accs, val_losses, val_accs = trainer.train(num_epochs=20)

    # Restore the checkpoint with the best validation accuracy.
    model.load_state_dict(torch.load('best_model.pth'))

    # Evaluate and inspect mistakes.
    class_names = ['negative', 'positive', 'neutral']
    accuracy, report, cm, probs = evaluate_model(model, val_loader, device, class_names)
    errors = analyze_errors(model, val_loader, preprocessor, device, class_names)

    # Persist the final model and the vocabulary.
    torch.save(model.state_dict(), 'final_model.pth')
    preprocessor.save_vocab('vocab.pkl')

    print("Training complete!")

if __name__ == '__main__':
    main()

Summary

This chapter demonstrated through a complete text classification project:

  1. Text Preprocessing: Complete workflow of cleaning, tokenization, vocabulary building
  2. Model Architectures: Different text classification models including LSTM, CNN, Transformer
  3. Training Framework: Complete training, validation, and saving process
  4. Model Evaluation: Evaluation methods including accuracy, classification report, error analysis
  5. Practical Applications: Implementation of single text and batch inference

This project template can be applied to various text classification tasks such as sentiment analysis, topic classification, spam detection, etc.

Content is for learning and research only.