PyTorch Text Classification Project
Project Overview
This chapter will demonstrate how to use PyTorch for natural language processing tasks through a complete text classification project. We will build a sentiment analysis system capable of determining the sentiment tendency of text (positive, negative, neutral).
Project Structure
text_classification/
├── data/ # Data directory
│ ├── raw/ # Raw data
│ ├── processed/ # Processed data
│ └── vocab/ # Vocabulary
├── models/ # Model definitions
│ ├── __init__.py
│ ├── lstm_classifier.py
│ ├── transformer_classifier.py
│ └── cnn_classifier.py
├── utils/ # Utility functions
│ ├── __init__.py
│ ├── data_loader.py
│ ├── text_processor.py
│ └── metrics.py
├── configs/ # Configuration files
├── train.py # Training script
├── evaluate.py # Evaluation script
└── inference.py # Inference script

Data Preprocessing
1. Text Preprocessor
python
import re
import string
import torch
from collections import Counter, defaultdict
from typing import List, Dict, Tuple
import jieba # Chinese word segmentation
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
class TextPreprocessor:
def __init__(self, language='en', max_vocab_size=50000, min_freq=2):
self.language = language
self.max_vocab_size = max_vocab_size
self.min_freq = min_freq
# Special tokens
self.PAD_TOKEN = '<PAD>'
self.UNK_TOKEN = '<UNK>'
self.SOS_TOKEN = '<SOS>'
self.EOS_TOKEN = '<EOS>'
# Vocabulary
self.vocab = {}
self.idx2word = {}
self.word_freq = Counter()
# Stopwords
if language == 'en':
try:
self.stop_words = set(stopwords.words('english'))
except:
self.stop_words = set()
else:
self.stop_words = set()
def clean_text(self, text: str) -> str:
"""Clean text"""
# Convert to lowercase
text = text.lower()
# Remove HTML tags
text = re.sub(r'<[^>]+>', '', text)
# Remove URLs
text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)
# Remove email addresses
text = re.sub(r'\S+@\S+', '', text)
# Remove numbers (optional)
# text = re.sub(r'\d+', '', text)
# Remove extra whitespace
text = re.sub(r'\s+', ' ', text).strip()
return text
def tokenize(self, text: str) -> List[str]:
"""Tokenize"""
if self.language == 'zh':
# Chinese word segmentation
tokens = list(jieba.cut(text))
else:
# English tokenization
tokens = word_tokenize(text)
# Remove punctuation and stopwords
tokens = [
token for token in tokens
if token not in string.punctuation and token not in self.stop_words
]
return tokens
def build_vocab(self, texts: List[str]):
"""Build vocabulary"""
print("Building vocabulary...")
# Count word frequencies
for text in texts:
cleaned_text = self.clean_text(text)
tokens = self.tokenize(cleaned_text)
self.word_freq.update(tokens)
# Create vocabulary list
vocab_list = [self.PAD_TOKEN, self.UNK_TOKEN, self.SOS_TOKEN, self.EOS_TOKEN]
# Sort by frequency, take top max_vocab_size words
sorted_words = sorted(self.word_freq.items(), key=lambda x: x[1], reverse=True)
for word, freq in sorted_words:
if freq >= self.min_freq and len(vocab_list) < self.max_vocab_size:
vocab_list.append(word)
# Build vocabulary mapping
self.vocab = {word: idx for idx, word in enumerate(vocab_list)}
self.idx2word = {idx: word for word, idx in self.vocab.items()}
print(f"Vocabulary size: {len(self.vocab)}")
print(f"Total word frequency: {sum(self.word_freq.values())}")
def text_to_sequence(self, text: str, max_length: int = None) -> List[int]:
"""Convert text to sequence"""
cleaned_text = self.clean_text(text)
tokens = self.tokenize(cleaned_text)
# Convert to indices
sequence = [
self.vocab.get(token, self.vocab[self.UNK_TOKEN])
for token in tokens
]
# Truncate or pad
if max_length:
if len(sequence) > max_length:
sequence = sequence[:max_length]
else:
sequence.extend([self.vocab[self.PAD_TOKEN]] * (max_length - len(sequence)))
return sequence
def sequence_to_text(self, sequence: List[int]) -> str:
"""Convert sequence to text"""
tokens = [
self.idx2word.get(idx, self.UNK_TOKEN)
for idx in sequence
if idx != self.vocab[self.PAD_TOKEN]
]
return ' '.join(tokens)
def save_vocab(self, filepath: str):
"""Save vocabulary"""
import pickle
vocab_data = {
'vocab': self.vocab,
'idx2word': self.idx2word,
'word_freq': self.word_freq,
'config': {
'language': self.language,
'max_vocab_size': self.max_vocab_size,
'min_freq': self.min_freq
}
}
with open(filepath, 'wb') as f:
pickle.dump(vocab_data, f)
def load_vocab(self, filepath: str):
"""Load vocabulary"""
import pickle
with open(filepath, 'rb') as f:
vocab_data = pickle.load(f)
self.vocab = vocab_data['vocab']
self.idx2word = vocab_data['idx2word']
self.word_freq = vocab_data['word_freq']
config = vocab_data['config']
self.language = config['language']
self.max_vocab_size = config['max_vocab_size']
self.min_freq = config['min_freq']
# Usage example
preprocessor = TextPreprocessor(language='en')
# Example texts
texts = [
"I love this movie! It's absolutely fantastic.",
"This film is terrible. I hate it.",
"The movie is okay, not great but not bad either."
]
# Build vocabulary
preprocessor.build_vocab(texts)
# Convert text to sequence
sequence = preprocessor.text_to_sequence(texts[0], max_length=20)
print(f"Original text: {texts[0]}")
print(f"Sequence: {sequence}")
print(f"Recovered: {preprocessor.sequence_to_text(sequence)}")2. Dataset Class
python
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch
class TextClassificationDataset(Dataset):
    """Map-style dataset wrapping raw texts + integer labels.

    Each item is a dict with a fixed-length id tensor ('input_ids'),
    the label tensor ('label'), and the original raw text ('text').
    """

    def __init__(self, texts, labels, preprocessor, max_length=128):
        self.texts = texts
        self.labels = labels
        self.preprocessor = preprocessor  # must provide text_to_sequence(text, max_length)
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        # Convert text to a padded/truncated id sequence
        sequence = self.preprocessor.text_to_sequence(text, self.max_length)
        return {
            'input_ids': torch.tensor(sequence, dtype=torch.long),
            'label': torch.tensor(label, dtype=torch.long),
            'text': text
        }
def create_data_loaders(train_texts, train_labels, val_texts, val_labels,
                        preprocessor, batch_size=32, max_length=128):
    """Create train/validation DataLoaders over TextClassificationDataset.

    Training data is shuffled; validation data is not. pin_memory speeds up
    host-to-GPU transfers when CUDA is used.
    """
    train_dataset = TextClassificationDataset(
        train_texts, train_labels, preprocessor, max_length
    )
    val_dataset = TextClassificationDataset(
        val_texts, val_labels, preprocessor, max_length
    )
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True
    )
    return train_loader, val_loader
# Example data loading
def load_imdb_data():
"""Load IMDB dataset example"""
# Using example data here, real data needs to be loaded in actual projects
train_texts = [
"I love this movie! It's absolutely fantastic.",
"This film is terrible. I hate it.",
"The movie is okay, not great but not bad either.",
"Amazing cinematography and great acting!",
"Boring and predictable plot."
] * 1000 # Expand data
train_labels = [1, 0, 2, 1, 0] * 1000 # 0: negative, 1: positive, 2: neutral
val_texts = train_texts[:500]
val_labels = train_labels[:500]
return train_texts, train_labels, val_texts, val_labelsModel Architecture
1. LSTM Classifier
python
import torch.nn as nn
import torch.nn.functional as F
class LSTMClassifier(nn.Module):
    """(Bi)LSTM text classifier with additive attention pooling.

    Embeds token ids (PAD = index 0), runs a multi-layer LSTM, pools the
    outputs (attention when a mask is given, mean otherwise), then classifies.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes,
                 num_layers=2, dropout=0.3, bidirectional=True):
        super(LSTMClassifier, self).__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        # Word embedding layer; padding_idx=0 keeps the PAD embedding at zero
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # LSTM layer (inter-layer dropout only applies with >1 layer)
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim, num_layers,
            batch_first=True, dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional
        )
        # Attention scorer over LSTM outputs (one scalar score per timestep)
        lstm_output_dim = hidden_dim * 2 if bidirectional else hidden_dim
        self.attention = nn.Linear(lstm_output_dim, 1)
        # Classifier head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(lstm_output_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, input_ids, attention_mask=None):
        """Return class logits of shape (batch_size, num_classes).

        input_ids: (batch_size, seq_len) long tensor of token ids.
        attention_mask: optional (batch_size, seq_len); 0 marks padding.
        """
        # Word embedding
        embedded = self.embedding(input_ids)  # (batch_size, seq_len, embed_dim)
        # LSTM
        lstm_out, (hidden, cell) = self.lstm(embedded)  # (batch_size, seq_len, hidden_dim*2)
        if attention_mask is not None:
            # Attention pooling: large negative score on padded positions so
            # softmax assigns them ~zero weight
            attention_weights = self.attention(lstm_out).squeeze(-1)  # (batch_size, seq_len)
            attention_weights = attention_weights.masked_fill(attention_mask == 0, -1e9)
            attention_weights = F.softmax(attention_weights, dim=1)
            # Weighted average over timesteps
            context = torch.sum(attention_weights.unsqueeze(-1) * lstm_out, dim=1)
        else:
            # Simple average pooling (NOTE(review): includes PAD positions)
            context = torch.mean(lstm_out, dim=1)
        # Classification
        logits = self.classifier(context)
        return logits
# Create model
def create_lstm_model(vocab_size, num_classes):
model = LSTMClassifier(
vocab_size=vocab_size,
embed_dim=128,
hidden_dim=256,
num_classes=num_classes,
num_layers=2,
dropout=0.3,
bidirectional=True
)
return model2. CNN Classifier
python
class CNNClassifier(nn.Module):
    """TextCNN-style classifier: parallel 1D convolutions + max-over-time pooling.

    NOTE(review): inputs must be at least max(filter_sizes) tokens long,
    otherwise the convolution fails — padding to max_length guarantees this.
    """

    def __init__(self, vocab_size, embed_dim, num_filters, filter_sizes,
                 num_classes, dropout=0.3):
        super(CNNClassifier, self).__init__()
        # Word embedding layer (PAD = index 0)
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # One Conv1d per n-gram window size
        self.convs = nn.ModuleList([
            nn.Conv1d(embed_dim, num_filters, kernel_size=fs)
            for fs in filter_sizes
        ])
        # Classifier head over the concatenated pooled features
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(len(filter_sizes) * num_filters, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes)
        )

    def forward(self, input_ids):
        """Return class logits of shape (batch_size, num_classes)."""
        # Word embedding, then move channels first for Conv1d
        embedded = self.embedding(input_ids)  # (batch_size, seq_len, embed_dim)
        embedded = embedded.transpose(1, 2)  # (batch_size, embed_dim, seq_len)
        # Convolution + max-over-time pooling per filter size
        conv_outputs = []
        for conv in self.convs:
            conv_out = F.relu(conv(embedded))  # (batch_size, num_filters, conv_seq_len)
            pooled = F.max_pool1d(conv_out, conv_out.size(2))  # (batch_size, num_filters, 1)
            conv_outputs.append(pooled.squeeze(2))
        # Concatenate all pooled features
        concat_output = torch.cat(conv_outputs, dim=1)  # (batch_size, len(filter_sizes) * num_filters)
        # Classification
        logits = self.classifier(concat_output)
        return logits
def create_cnn_model(vocab_size, num_classes):
model = CNNClassifier(
vocab_size=vocab_size,
embed_dim=128,
num_filters=100,
filter_sizes=[3, 4, 5],
num_classes=num_classes,
dropout=0.3
)
return model3. Transformer Classifier
python
class TransformerClassifier(nn.Module):
    """Transformer-encoder text classifier with learned positional encodings.

    Pools encoder outputs with a padding-aware mean, then classifies.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_layers,
                 num_classes, max_length=512, dropout=0.1):
        super(TransformerClassifier, self).__init__()
        self.embed_dim = embed_dim
        self.max_length = max_length
        # Word embedding (PAD = index 0) and learned positional encoding
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.pos_encoding = nn.Parameter(torch.randn(max_length, embed_dim))
        # Transformer encoder (batch_first: inputs are (batch, seq, dim))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        # Classifier head
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, num_classes)
        )

    def forward(self, input_ids, attention_mask=None):
        """Return class logits of shape (batch_size, num_classes).

        attention_mask: optional (batch_size, seq_len) bool/int; 0 = padding.
        When omitted, it is derived from input_ids (PAD id 0).
        """
        seq_len = input_ids.size(1)
        # Word embedding + positional encoding (truncated to seq_len)
        embedded = self.embedding(input_ids)
        embedded = embedded + self.pos_encoding[:seq_len, :].unsqueeze(0)
        # Derive padding mask from PAD token id 0 if not supplied
        if attention_mask is None:
            attention_mask = (input_ids != 0)
        # Transformer encoding; src_key_padding_mask expects True at PAD positions
        transformer_out = self.transformer(
            embedded,
            src_key_padding_mask=~attention_mask
        )
        # Masked global average pooling (ignore padding positions);
        # clamp the denominator so an all-padding row cannot divide by zero
        mask_expanded = attention_mask.unsqueeze(-1).float()
        sum_embeddings = torch.sum(transformer_out * mask_expanded, dim=1)
        sum_mask = torch.sum(mask_expanded, dim=1).clamp(min=1e-9)
        pooled = sum_embeddings / sum_mask
        # Classification
        logits = self.classifier(pooled)
        return logits
def create_transformer_model(vocab_size, num_classes):
model = TransformerClassifier(
vocab_size=vocab_size,
embed_dim=256,
num_heads=8,
num_layers=6,
num_classes=num_classes,
max_length=512,
dropout=0.1
)
return modelTraining Framework
1. Trainer
python
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import numpy as np
class TextClassificationTrainer:
def __init__(self, model, train_loader, val_loader, num_classes, device):
self.model = model.to(device)
self.train_loader = train_loader
self.val_loader = val_loader
self.num_classes = num_classes
self.device = device
# Loss function and optimizer
self.criterion = nn.CrossEntropyLoss()
self.optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', patience=3, factor=0.5)
# Training history
self.train_losses = []
self.train_accs = []
self.val_losses = []
self.val_accs = []
# Best model
self.best_val_acc = 0.0
def train_epoch(self):
"""Train one epoch"""
self.model.train()
total_loss = 0
all_preds = []
all_labels = []
for batch in self.train_loader:
input_ids = batch['input_ids'].to(self.device)
labels = batch['label'].to(self.device)
# Create attention mask
attention_mask = (input_ids != 0)
self.optimizer.zero_grad()
# Forward pass
if isinstance(self.model, TransformerClassifier):
logits = self.model(input_ids, attention_mask)
else:
logits = self.model(input_ids)
loss = self.criterion(logits, labels)
# Backward pass
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
# Statistics
total_loss += loss.item()
preds = torch.argmax(logits, dim=1)
all_preds.extend(preds.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
avg_loss = total_loss / len(self.train_loader)
accuracy = accuracy_score(all_labels, all_preds)
return avg_loss, accuracy
def validate_epoch(self):
"""Validate one epoch"""
self.model.eval()
total_loss = 0
all_preds = []
all_labels = []
with torch.no_grad():
for batch in self.val_loader:
input_ids = batch['input_ids'].to(self.device)
labels = batch['label'].to(self.device)
attention_mask = (input_ids != 0)
if isinstance(self.model, TransformerClassifier):
logits = self.model(input_ids, attention_mask)
else:
logits = self.model(input_ids)
loss = self.criterion(logits, labels)
total_loss += loss.item()
preds = torch.argmax(logits, dim=1)
all_preds.extend(preds.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
avg_loss = total_loss / len(self.val_loader)
accuracy = accuracy_score(all_labels, all_preds)
return avg_loss, accuracy, all_preds, all_labels
def train(self, num_epochs):
"""Complete training process"""
print(f"Starting training, {num_epochs} epochs")
for epoch in range(num_epochs):
# Train
train_loss, train_acc = self.train_epoch()
# Validate
val_loss, val_acc, val_preds, val_labels = self.validate_epoch()
# Update learning rate
self.scheduler.step(val_acc)
# Record history
self.train_losses.append(train_loss)
self.train_accs.append(train_acc)
self.val_losses.append(val_loss)
self.val_accs.append(val_acc)
# Print results
print(f'Epoch {epoch+1}/{num_epochs}:')
print(f' Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
print(f' Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
print(f' LR: {self.optimizer.param_groups[0]["lr"]:.6f}')
# Save best model
if val_acc > self.best_val_acc:
self.best_val_acc = val_acc
torch.save(self.model.state_dict(), 'best_model.pth')
print(f' ✓ New best model! Validation accuracy: {val_acc:.4f}')
print('-' * 50)
print(f'Training complete! Best validation accuracy: {self.best_val_acc:.4f}')
return self.train_losses, self.train_accs, self.val_losses, self.val_accsModel Evaluation
1. Detailed Evaluation
python
def evaluate_model(model, test_loader, device, class_names):
"""Evaluate model in detail"""
model.eval()
all_preds = []
all_labels = []
all_probs = []
with torch.no_grad():
for batch in test_loader:
input_ids = batch['input_ids'].to(device)
labels = batch['label'].to(device)
attention_mask = (input_ids != 0)
if isinstance(model, TransformerClassifier):
logits = model(input_ids, attention_mask)
else:
logits = model(input_ids)
probs = F.softmax(logits, dim=1)
preds = torch.argmax(logits, dim=1)
all_preds.extend(preds.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
all_probs.extend(probs.cpu().numpy())
# Calculate metrics
accuracy = accuracy_score(all_labels, all_preds)
report = classification_report(all_labels, all_preds, target_names=class_names)
cm = confusion_matrix(all_labels, all_preds)
print(f"Test accuracy: {accuracy:.4f}")
print("\nClassification report:")
print(report)
return accuracy, report, cm, all_probs2. Error Analysis
python
def analyze_errors(model, test_loader, preprocessor, device, class_names, num_examples=10):
    """Collect and print up to num_examples misclassified samples with probabilities."""
    model.eval()
    errors = []
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['label'].to(device)
            texts = batch['text']
            attention_mask = (input_ids != 0)
            if isinstance(model, TransformerClassifier):
                logits = model(input_ids, attention_mask)
            else:
                logits = model(input_ids)
            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(logits, dim=1)
            # Find incorrect predictions in this batch
            wrong_mask = preds != labels
            if wrong_mask.any():
                wrong_indices = torch.where(wrong_mask)[0]
                for idx in wrong_indices:
                    errors.append({
                        'text': texts[idx],
                        'true_label': class_names[labels[idx].item()],
                        'pred_label': class_names[preds[idx].item()],
                        'confidence': probs[idx].max().item(),
                        'all_probs': probs[idx].cpu().numpy()
                    })
                    if len(errors) >= num_examples:
                        break
            if len(errors) >= num_examples:
                break
    # Print error analysis
    print("Incorrect Prediction Analysis:")
    print("=" * 80)
    for i, error in enumerate(errors):
        print(f"\nExample {i+1}:")
        print(f"Text: {error['text'][:200]}...")
        print(f"True label: {error['true_label']}")
        print(f"Predicted label: {error['pred_label']}")
        print(f"Confidence: {error['confidence']:.4f}")
        # Show probabilities for all classes
        for j, prob in enumerate(error['all_probs']):
            print(f" {class_names[j]}: {prob:.4f}")
    return errors

Inference and Application
1. Single Text Inference
python
def predict_single_text(model, text, preprocessor, device, class_names, max_length=128):
    """Predict the class of a single text; return a dict with the label, confidence, and all class probabilities."""
    model.eval()
    # Preprocess into a batch of one
    sequence = preprocessor.text_to_sequence(text, max_length)
    input_ids = torch.tensor([sequence], dtype=torch.long).to(device)
    attention_mask = (input_ids != 0)
    with torch.no_grad():
        # Only the transformer consumes the mask
        if isinstance(model, TransformerClassifier):
            logits = model(input_ids, attention_mask)
        else:
            logits = model(input_ids)
        probs = F.softmax(logits, dim=1)
        pred_class = torch.argmax(logits, dim=1).item()
        confidence = probs[0][pred_class].item()
    # Get probabilities for all classes
    all_probs = probs[0].cpu().numpy()
    result = {
        'text': text,
        'predicted_class': class_names[pred_class],
        'confidence': confidence,
        'all_probabilities': {
            class_names[i]: float(prob) for i, prob in enumerate(all_probs)
        }
    }
    return result
# Usage example
text = "I absolutely love this movie! It's fantastic!"
result = predict_single_text(model, text, preprocessor, device, ['negative', 'positive', 'neutral'])
print(f"Prediction result: {result}")2. Batch Inference
python
def batch_predict(model, texts, preprocessor, device, class_names, batch_size=32, max_length=128):
"""Batch prediction"""
model.eval()
results = []
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i+batch_size]
# Preprocess batch data
sequences = [preprocessor.text_to_sequence(text, max_length) for text in batch_texts]
input_ids = torch.tensor(sequences, dtype=torch.long).to(device)
attention_mask = (input_ids != 0)
with torch.no_grad():
if isinstance(model, TransformerClassifier):
logits = model(input_ids, attention_mask)
else:
logits = model(input_ids)
probs = F.softmax(logits, dim=1)
pred_classes = torch.argmax(logits, dim=1)
# Process results
for j, text in enumerate(batch_texts):
pred_class = pred_classes[j].item()
confidence = probs[j][pred_class].item()
results.append({
'text': text,
'predicted_class': class_names[pred_class],
'confidence': confidence
})
return resultsComplete Training Script
python
def main():
    """End-to-end demo: data -> vocab -> model -> training -> evaluation -> save."""
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # Load data
    train_texts, train_labels, val_texts, val_labels = load_imdb_data()
    # Create preprocessor and build the vocabulary from the training set only
    preprocessor = TextPreprocessor(language='en')
    preprocessor.build_vocab(train_texts)
    # Create data loaders
    train_loader, val_loader = create_data_loaders(
        train_texts, train_labels, val_texts, val_labels,
        preprocessor, batch_size=32, max_length=128
    )
    # Create model
    vocab_size = len(preprocessor.vocab)
    num_classes = 3  # negative, positive, neutral
    # Select model type (uncomment one of the alternatives to swap architectures)
    model = create_lstm_model(vocab_size, num_classes)
    # model = create_cnn_model(vocab_size, num_classes)
    # model = create_transformer_model(vocab_size, num_classes)
    print(f"Model parameter count: {sum(p.numel() for p in model.parameters()):,}")
    # Create trainer
    trainer = TextClassificationTrainer(model, train_loader, val_loader, num_classes, device)
    # Train model
    train_losses, train_accs, val_losses, val_accs = trainer.train(num_epochs=20)
    # Load best checkpoint; map_location keeps this working on CPU-only machines
    model.load_state_dict(torch.load('best_model.pth', map_location=device))
    # Evaluate model
    class_names = ['negative', 'positive', 'neutral']
    accuracy, report, cm, probs = evaluate_model(model, val_loader, device, class_names)
    # Error analysis
    errors = analyze_errors(model, val_loader, preprocessor, device, class_names)
    # Save model and preprocessor
    torch.save(model.state_dict(), 'final_model.pth')
    preprocessor.save_vocab('vocab.pkl')
    print("Training complete!")


if __name__ == '__main__':
    main()

Summary
This chapter demonstrated through a complete text classification project:
- Text Preprocessing: Complete workflow of cleaning, tokenization, vocabulary building
- Model Architectures: Different text classification models including LSTM, CNN, Transformer
- Training Framework: Complete training, validation, and saving process
- Model Evaluation: Evaluation methods including accuracy, classification report, error analysis
- Practical Applications: Implementation of single text and batch inference
This project template can be applied to various text classification tasks such as sentiment analysis, topic classification, spam detection, etc.