PyTorch Tensor Basics
What is a Tensor?
Tensors are the most fundamental data structure in PyTorch, and can be understood as a generalization of multi-dimensional arrays:
- 0D Tensor: Scalar
- 1D Tensor: Vector
- 2D Tensor: Matrix
- 3D and above: Higher-dimensional tensors
python
import torch
# Tensors of different dimensions (the dimension count is the number of axes)
scalar = torch.tensor(3.14) # 0D: a single value, no axes
vector = torch.tensor([1, 2, 3]) # 1D: one axis, like a list of numbers
matrix = torch.tensor([[1, 2], [3, 4]]) # 2D: rows and columns
tensor_3d = torch.zeros(2, 3, 4) # 3D
Creating Tensors
1. Creating from Data
python
import torch
import numpy as np
# Create from Python list; dtype is inferred from the elements (int64 here)
data = [[1, 2], [3, 4]]
x = torch.tensor(data)
print(f"Created from list: {x}")
# Create from NumPy array
# NOTE: from_numpy shares memory with the NumPy array — mutating one mutates the other
np_array = np.array([[1, 2], [3, 4]])
x = torch.from_numpy(np_array)
print(f"Created from NumPy: {x}")
# Specify data type explicitly instead of relying on inference
x = torch.tensor([1.0, 2.0], dtype=torch.float32)
print(f"Specified type: {x.dtype}")
2. Creating with Built-in Functions
python
# Zero tensor
zeros = torch.zeros(3, 4)
print(f"Zero tensor:\n{zeros}")
# One tensor
ones = torch.ones(2, 3)
print(f"One tensor:\n{ones}")
# Identity matrix
eye = torch.eye(3)
print(f"Identity matrix:\n{eye}")
# Random tensors
rand = torch.rand(2, 3) # Uniform distribution [0, 1)
randn = torch.randn(2, 3) # Standard normal distribution (mean 0, std 1)
print(f"Random tensor:\n{rand}")
print(f"Normal distribution:\n{randn}")
# Random integers in specified range — the upper bound 10 is exclusive
randint = torch.randint(0, 10, (3, 3))
print(f"Random integers:\n{randint}")
# Arithmetic sequence
arange = torch.arange(0, 10, 2) # start, end (exclusive), step -> [0, 2, 4, 6, 8]
linspace = torch.linspace(0, 1, 5) # start, end (inclusive), number of points
print(f"Arithmetic sequence: {arange}")
print(f"Linear space: {linspace}")
3. Creating from Other Tensors
python
x = torch.tensor([[1, 2], [3, 4]])
# Create tensors with the same shape (and by default the same dtype/device) as x
zeros_like = torch.zeros_like(x)
ones_like = torch.ones_like(x)
# x is integer-typed, so rand_like needs an explicit float dtype (uniform sampling needs floats)
rand_like = torch.rand_like(x, dtype=torch.float)
print(f"Original tensor:\n{x}")
print(f"Zero tensor with same shape:\n{zeros_like}")
Tensor Properties
python
x = torch.randn(3, 4, 5)
print(f"Shape: {x.shape}") # or x.size()
print(f"Dimensions: {x.ndim}") # or x.dim()
print(f"Total elements: {x.numel()}") # 3 * 4 * 5 = 60
print(f"Data type: {x.dtype}")
print(f"Device: {x.device}") # cpu unless the tensor was created on an accelerator
print(f"Requires gradient: {x.requires_grad}")
Data Types
PyTorch supports various data types:
python
# Integer types (the numeric suffix is the bit width per element)
int8 = torch.tensor([1, 2, 3], dtype=torch.int8)
int16 = torch.tensor([1, 2, 3], dtype=torch.int16)
int32 = torch.tensor([1, 2, 3], dtype=torch.int32)
int64 = torch.tensor([1, 2, 3], dtype=torch.int64) # default dtype for Python ints
# Float types
float16 = torch.tensor([1.0, 2.0], dtype=torch.float16) # Half precision
float32 = torch.tensor([1.0, 2.0], dtype=torch.float32) # Single precision (default for Python floats)
float64 = torch.tensor([1.0, 2.0], dtype=torch.float64) # Double precision
# Boolean type
bool_tensor = torch.tensor([True, False], dtype=torch.bool)
# Type conversion — each call returns a NEW tensor; the original is unchanged
x = torch.tensor([1, 2, 3])
x_float = x.float() # Convert to float32
x_double = x.double() # Convert to float64
x_int = x_float.int() # Convert to int32
print(f"Original type: {x.dtype}")
print(f"After conversion: {x_float.dtype}")
Tensor Operations
1. Indexing and Slicing
python
x = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
# Basic indexing
print(f"First row: {x[0]}")
print(f"First column: {x[:, 0]}")
print(f"Specific element: {x[1, 2]}") # row 1, column 2 -> 7
# Slicing (the end index is exclusive, as in Python lists)
print(f"First two rows: {x[:2]}")
print(f"Last two columns: {x[:, -2:]}")
print(f"Sub-matrix: {x[1:3, 1:3]}")
# Boolean indexing — returns a flattened 1D tensor of the matching elements
mask = x > 6
print(f"Elements greater than 6: {x[mask]}")
# Advanced indexing with an integer index tensor
indices = torch.tensor([0, 2])
print(f"Select specific rows: {x[indices]}")
2. Shape Transformation
python
x = torch.randn(2, 3, 4) # 24 elements in total
# Reshape — the target shape must keep the same number of elements
reshaped = x.view(6, 4) # view requires contiguous memory
reshaped2 = x.reshape(8, 3) # reshape is more flexible (copies when a view is impossible)
# Flatten
flattened = x.flatten() # Flatten to 1D
flattened_partial = x.flatten(start_dim=1) # Flatten from dimension 1 -> shape (2, 12)
# Add dimensions
unsqueezed = x.unsqueeze(0) # Add dimension at position 0 -> (1, 2, 3, 4)
unsqueezed2 = x.unsqueeze(-1) # Add dimension at the end -> (2, 3, 4, 1)
# Remove dimensions
squeezed = unsqueezed.squeeze(0) # Remove dimension 0 (only removes size-1 dims)
# Transpose
transposed = x.transpose(0, 1) # Swap dimensions 0 and 1 -> (3, 2, 4)
permuted = x.permute(2, 0, 1) # Rearrange dimensions -> (4, 2, 3)
print(f"Original shape: {x.shape}")
print(f"After reshape: {reshaped.shape}")
print(f"After transpose: {transposed.shape}")
3. Concatenation and Splitting
python
x = torch.randn(2, 3)
y = torch.randn(2, 3)
# Concatenate — every dimension except the concat dimension must match
cat_dim0 = torch.cat([x, y], dim=0) # Concatenate along dimension 0 -> (4, 3)
cat_dim1 = torch.cat([x, y], dim=1) # Concatenate along dimension 1 -> (2, 6)
# Stack (adds new dimension)
stacked = torch.stack([x, y], dim=0) # -> (2, 2, 3)
# Split
chunks = torch.chunk(cat_dim0, 2, dim=0) # Split into 2 chunks
splits = torch.split(cat_dim1, 3, dim=1) # Split by specified size (3 columns per piece)
print(f"Concatenation result shape: {cat_dim0.shape}")
print(f"Stack result shape: {stacked.shape}")
print(f"Number of split chunks: {len(chunks)}")
Mathematical Operations
1. Basic Operations
python
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float)
y = torch.tensor([[5, 6], [7, 8]], dtype=torch.float)
# Element-wise operations — each returns a new tensor
add = x + y # or torch.add(x, y)
sub = x - y # or torch.sub(x, y)
mul = x * y # or torch.mul(x, y)
div = x / y # or torch.div(x, y)
# In-place operations (modify original tensor; the trailing underscore marks in-place)
x.add_(y) # x = x + y — x changes here, but add/sub/mul/div above keep their old values
print(f"After in-place operation: {x}")
# Scalar operations (the scalar broadcasts over every element; note x is the mutated value)
scaled = x * 2
shifted = x + 1
print(f"Addition: {add}")
print(f"Multiplication: {mul}")
2. Matrix Operations
python
x = torch.randn(3, 4)
y = torch.randn(4, 5)
# Matrix multiplication — inner dimensions must match: (3, 4) @ (4, 5) -> (3, 5)
matmul = torch.matmul(x, y) # or x @ y
mm = torch.mm(x, y) # 2D matrix multiplication only (no broadcasting)
# Batch matrix multiplication — multiplies each of the 10 pairs independently
batch_x = torch.randn(10, 3, 4)
batch_y = torch.randn(10, 4, 5)
batch_result = torch.bmm(batch_x, batch_y) # -> (10, 3, 5)
print(f"Matrix multiplication result shape: {matmul.shape}")
print(f"Batch matrix multiplication result shape: {batch_result.shape}")
3. Statistical Operations
python
x = torch.randn(3, 4)
# Basic statistics
mean = x.mean() # Global mean (a 0D tensor)
mean_dim = x.mean(dim=0) # Mean along dimension 0 (per column -> shape (4,))
std = x.std() # Standard deviation
var = x.var() # Variance
# Min/max values
max_val = x.max()
min_val = x.min()
max_indices = x.argmax() # Index of maximum value (into the flattened tensor)
min_indices = x.argmin() # Index of minimum value
# Sum
sum_all = x.sum()
sum_dim = x.sum(dim=1) # Sum along dimension 1 (per row -> shape (3,))
print(f"Mean: {mean:.4f}")
print(f"Mean by dimension: {mean_dim}")
print(f"Maximum: {max_val:.4f}")
Broadcasting
PyTorch supports broadcasting, allowing tensors of different shapes to operate together:
python
# Broadcasting rules:
# 1. Compare dimensions from the right
# 2. Dimensions can broadcast if they are equal or one is 1
# 3. Missing dimensions are treated as 1
x = torch.randn(3, 4)
y = torch.randn(4) # Treated as (1, 4), then stretched to (3, 4)
z = torch.randn(3, 1) # Size-1 dimension stretched -> (3, 4)
result1 = x + y # (3, 4) + (4,) -> (3, 4)
result2 = x + z # (3, 4) + (3, 1) -> (3, 4)
print(f"x shape: {x.shape}")
print(f"y shape: {y.shape}")
print(f"Broadcasting result shape: {result1.shape}")
# Manual broadcasting — expand creates a view without copying any data
y_expanded = y.expand_as(x) # Expand to same shape as x
print(f"After manual expansion: {y_expanded.shape}")
Device Management
python
# Check if CUDA is available; fall back to CPU otherwise
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Create tensor directly on the chosen device (cheaper than creating then moving)
x = torch.randn(3, 4, device=device)
# Move tensor to different device (.to returns a copy when the device differs)
x_cpu = torch.randn(3, 4)
x_gpu = x_cpu.to(device) # Moves to GPU when available; stays on CPU otherwise
x_back = x_gpu.cpu() # Move back to CPU
print(f"CPU tensor device: {x_cpu.device}")
print(f"GPU tensor device: {x_gpu.device}")
Memory Management
python
# Check if tensors share memory
x = torch.randn(3, 4)
y = x.view(4, 3) # view shares memory — writing through y also changes x
z = x.clone() # clone creates an independent copy
# NOTE(review): Tensor.storage() is deprecated in recent PyTorch; untyped_storage() is the modern spelling
print(f"x and y share memory: {x.storage().data_ptr() == y.storage().data_ptr()}")
print(f"x and z share memory: {x.storage().data_ptr() == z.storage().data_ptr()}")
# Detach tensor (for gradient computation)
x = torch.randn(3, 4, requires_grad=True)
y = x.detach() # Shares data with x but is cut out of the autograd graph
print(f"Original tensor requires grad: {x.requires_grad}")
print(f"After detach requires grad: {y.requires_grad}")
Practical Tips
1. Tensor Initialization Tips
python
import torch
from torch import nn  # needed for nn.init — missing in the original snippet

# Xavier (Glorot) initialization: uniform values scaled by fan-in + fan-out,
# keeping activation variance roughly constant across layers.
def xavier_init(tensor):
    nn.init.xavier_uniform_(tensor)

# He (Kaiming) initialization: scaling suited to ReLU-family activations.
def he_init(tensor):
    nn.init.kaiming_uniform_(tensor)

# Custom initialization: fill uniformly in [-0.1, 0.1] without recording
# the in-place op in the autograd graph.
def custom_init(tensor):
    with torch.no_grad():
        tensor.uniform_(-0.1, 0.1)
        tensor.uniform_(-0.1, 0.1)
2. Performance Optimization
python
# Use appropriate data types: float32 halves memory and bandwidth vs float64
x = torch.randn(1000, 1000, dtype=torch.float32) # instead of float64
# Avoid frequent CPU-GPU transfers: create the tensor on the target device once.
# Guard with is_available() — hard-coding device='cuda' crashes on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(1000, 1000, device=device)
# Perform all computations on the same device before moving results
result = x.mm(x.t()).sum()
# Use in-place operations to save memory
x.add_(1) # instead of x = x + 1
Summary
Tensors are the core of PyTorch, and mastering tensor operations is fundamental to deep learning:
- Creating Tensors: Understand various creation methods and data types
- Shape Operations: Proficiently use view, reshape, transpose, etc.
- Mathematical Operations: Master basic operations and matrix operations
- Broadcasting: Understand operation rules for tensors of different shapes
- Device Management: Reasonably use CPU and GPU resources
These basic operations will be frequently used in subsequent neural network construction, so practice is recommended!