Skip to content

PyTorch Tensor Basics

What is a Tensor?

Tensors are the most fundamental data structure in PyTorch, and can be understood as a generalization of multi-dimensional arrays:

  • 0D Tensor: Scalar
  • 1D Tensor: Vector
  • 2D Tensor: Matrix
  • 3D and above: Higher-dimensional tensors
python
import torch

# One example tensor per rank, from scalar up to 3-D
scalar = torch.tensor(3.14)              # rank 0: a single value
vector = torch.tensor([1, 2, 3])         # rank 1
matrix = torch.tensor([[1, 2], [3, 4]])  # rank 2
tensor_3d = torch.zeros(2, 3, 4)         # rank 3, zero-filled

Creating Tensors

1. Creating from Data

python
import torch
import numpy as np

# A nested Python list converts to a tensor directly
nested = [[1, 2], [3, 4]]
x = torch.tensor(nested)
print(f"Created from list: {x}")

# from_numpy wraps an existing NumPy array (the two share memory)
arr = np.array([[1, 2], [3, 4]])
x = torch.from_numpy(arr)
print(f"Created from NumPy: {x}")

# The dtype can be forced at construction time
x = torch.tensor([1.0, 2.0], dtype=torch.float32)
print(f"Specified type: {x.dtype}")

2. Creating with Built-in Functions

python
# All-zeros tensor
zeros = torch.zeros(3, 4)
print(f"Zero tensor:\n{zeros}")

# All-ones tensor
ones = torch.ones(2, 3)
print(f"One tensor:\n{ones}")

# 3x3 identity matrix
eye = torch.eye(3)
print(f"Identity matrix:\n{eye}")

# Random sampling
rand = torch.rand(2, 3)    # uniform over [0, 1)
randn = torch.randn(2, 3)  # samples from N(0, 1)
print(f"Random tensor:\n{rand}")
print(f"Normal distribution:\n{randn}")

# Random integers drawn from [low, high)
randint = torch.randint(low=0, high=10, size=(3, 3))
print(f"Random integers:\n{randint}")

# Evenly spaced values
arange = torch.arange(start=0, end=10, step=2)  # end is exclusive
linspace = torch.linspace(0, 1, steps=5)        # both endpoints included
print(f"Arithmetic sequence: {arange}")
print(f"Linear space: {linspace}")

3. Creating from Other Tensors

python
x = torch.tensor([[1, 2], [3, 4]])

# The *_like helpers copy x's shape (and, by default, dtype/device)
zeros_like = torch.zeros_like(x)
ones_like = torch.ones_like(x)
rand_like = torch.rand_like(x, dtype=torch.float)  # rand_like needs a float dtype

print(f"Original tensor:\n{x}")
print(f"Zero tensor with same shape:\n{zeros_like}")

Tensor Properties

python
x = torch.randn(3, 4, 5)

# Every tensor carries metadata describing its layout and storage
print(f"Shape: {x.size()}")        # size() is the method form of .shape
print(f"Dimensions: {x.dim()}")    # dim() is the method form of .ndim
print(f"Total elements: {x.numel()}")
print(f"Data type: {x.dtype}")
print(f"Device: {x.device}")
print(f"Requires gradient: {x.requires_grad}")

Data Types

PyTorch supports various data types:

python
# Integer types
int8 = torch.tensor([1, 2, 3], dtype=torch.int8)
int16 = torch.tensor([1, 2, 3], dtype=torch.int16)
int32 = torch.tensor([1, 2, 3], dtype=torch.int32)
int64 = torch.tensor([1, 2, 3], dtype=torch.int64)

# Floating-point types
float16 = torch.tensor([1.0, 2.0], dtype=torch.float16)  # half precision
float32 = torch.tensor([1.0, 2.0], dtype=torch.float32)  # single precision
float64 = torch.tensor([1.0, 2.0], dtype=torch.float64)  # double precision

# Boolean type
bool_tensor = torch.tensor([True, False], dtype=torch.bool)

# Conversions: .to(dtype) is the general form of .float()/.double()/.int()
x = torch.tensor([1, 2, 3])
x_float = x.to(torch.float32)    # same as x.float()
x_double = x.to(torch.float64)   # same as x.double()
x_int = x_float.to(torch.int32)  # same as x_float.int()

print(f"Original type: {x.dtype}")
print(f"After conversion: {x_float.dtype}")

Tensor Operations

1. Indexing and Slicing

python
x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])

# Single-axis and per-element indexing
print(f"First row: {x[0]}")
print(f"First column: {x[:, 0]}")
print(f"Specific element: {x[1, 2]}")

# Ranges work per axis, just like Python lists
print(f"First two rows: {x[:2]}")
print(f"Last two columns: {x[:, -2:]}")
print(f"Sub-matrix: {x[1:3, 1:3]}")

# A boolean mask selects matching elements (result is 1-D)
mask = x > 6
print(f"Elements greater than 6: {x[mask]}")

# An integer index tensor gathers whole rows
indices = torch.tensor([0, 2])
print(f"Select specific rows: {x[indices]}")

2. Shape Transformation

python
x = torch.randn(2, 3, 4)

# Reshaping: the total element count (24) must be preserved
reshaped = x.view(6, 4)      # view needs contiguous memory
reshaped2 = x.reshape(8, 3)  # reshape copies when a view is impossible

# Collapse dimensions
flattened = x.flatten()                     # everything into 1-D
flattened_partial = x.flatten(start_dim=1)  # keep dim 0, merge the rest

# Insert size-1 dimensions (functional forms)
unsqueezed = torch.unsqueeze(x, 0)    # new leading dimension
unsqueezed2 = torch.unsqueeze(x, -1)  # new trailing dimension

# Drop a size-1 dimension
squeezed = torch.squeeze(unsqueezed, 0)

# Reorder dimensions
transposed = torch.transpose(x, 0, 1)  # exchange dims 0 and 1
permuted = x.permute(2, 0, 1)          # arbitrary reordering

print(f"Original shape: {x.shape}")
print(f"After reshape: {reshaped.shape}")
print(f"After transpose: {transposed.shape}")

3. Concatenation and Splitting

python
x = torch.randn(2, 3)
y = torch.randn(2, 3)

# cat joins along an existing dimension
cat_dim0 = torch.cat((x, y), dim=0)  # rows stacked: (4, 3)
cat_dim1 = torch.cat((x, y), dim=1)  # columns joined: (2, 6)

# stack inserts a brand-new dimension
stacked = torch.stack((x, y), dim=0)

# Splitting (method forms)
chunks = cat_dim0.chunk(2, dim=0)  # into 2 equal pieces
splits = cat_dim1.split(3, dim=1)  # pieces of width 3

print(f"Concatenation result shape: {cat_dim0.shape}")
print(f"Stack result shape: {stacked.shape}")
print(f"Number of split chunks: {len(chunks)}")

Mathematical Operations

1. Basic Operations

python
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float)
y = torch.tensor([[5, 6], [7, 8]], dtype=torch.float)

# Element-wise arithmetic via the functional forms
add = torch.add(x, y)  # same as x + y
sub = torch.sub(x, y)  # same as x - y
mul = torch.mul(x, y)  # same as x * y
div = torch.div(x, y)  # same as x / y

# Trailing-underscore methods mutate the receiver in place
x.add_(y)  # x now holds the element-wise sum
print(f"After in-place operation: {x}")

# Python scalars broadcast over every element
scaled = x * 2
shifted = x + 1

print(f"Addition: {add}")
print(f"Multiplication: {mul}")

2. Matrix Operations

python
x = torch.randn(3, 4)
y = torch.randn(4, 5)

# Matrix product: operator form and the 2-D-only mm
matmul = x @ y  # same as torch.matmul(x, y)
mm = x.mm(y)    # mm is restricted to 2-D inputs

# Batched matrix product over a leading batch dimension
batch_x = torch.randn(10, 3, 4)
batch_y = torch.randn(10, 4, 5)
batch_result = batch_x.bmm(batch_y)

print(f"Matrix multiplication result shape: {matmul.shape}")
print(f"Batch matrix multiplication result shape: {batch_result.shape}")

3. Statistical Operations

python
x = torch.randn(3, 4)

# Whole-tensor reductions (functional forms)
mean = torch.mean(x)             # global mean
mean_dim = torch.mean(x, dim=0)  # one mean per column
std = torch.std(x)               # standard deviation
var = torch.var(x)               # variance

# Extremes and their flattened indices
max_val = torch.max(x)
min_val = torch.min(x)
max_indices = torch.argmax(x)  # flat index of the largest element
min_indices = torch.argmin(x)  # flat index of the smallest element

# Sums
sum_all = torch.sum(x)
sum_dim = torch.sum(x, dim=1)  # one sum per row

print(f"Mean: {mean:.4f}")
print(f"Mean by dimension: {mean_dim}")
print(f"Maximum: {max_val:.4f}")

Broadcasting

PyTorch supports broadcasting, allowing tensors of different shapes to operate together:

python
# Broadcasting rules:
# 1. Align shapes starting from the trailing (rightmost) dimension
# 2. Two sizes are compatible when they are equal, or when either is 1
# 3. Absent leading dimensions behave as size 1

x = torch.randn(3, 4)
y = torch.randn(4)      # treated as (1, 4), then stretched to (3, 4)
z = torch.randn(3, 1)   # stretched along dim 1 to (3, 4)

result1 = x + y  # (3, 4) + (4,) -> (3, 4)
result2 = x + z  # (3, 4) + (3, 1) -> (3, 4)

print(f"x shape: {x.shape}")
print(f"y shape: {y.shape}")
print(f"Broadcasting result shape: {result1.shape}")

# Explicit broadcast: expand returns a view, copying no data
y_expanded = y.expand_as(x)  # view of y with x's shape
print(f"After manual expansion: {y_expanded.shape}")

Device Management

python
# Pick the GPU when present, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Allocate directly on the chosen device
x = torch.randn(3, 4, device=device)

# .to()/.cpu() return a tensor living on the target device
x_cpu = torch.randn(3, 4)
x_gpu = x_cpu.to(device)  # copy to the selected device
x_back = x_gpu.cpu()      # back to host memory

print(f"CPU tensor device: {x_cpu.device}")
print(f"GPU tensor device: {x_gpu.device}")

Memory Management

python
# Check whether tensors share the same underlying storage
x = torch.randn(3, 4)
y = x.view(4, 3)  # view reinterprets the same storage (no copy)
z = x.clone()     # clone allocates fresh storage

# Tensor.storage() is deprecated in PyTorch 2.x; untyped_storage() is the
# supported accessor and yields the same data-pointer comparison
print(f"x and y share memory: {x.untyped_storage().data_ptr() == y.untyped_storage().data_ptr()}")
print(f"x and z share memory: {x.untyped_storage().data_ptr() == z.untyped_storage().data_ptr()}")

# Detach tensor (for gradient computation)
x = torch.randn(3, 4, requires_grad=True)
y = x.detach()  # shares data with x but is cut out of the autograd graph

print(f"Original tensor requires grad: {x.requires_grad}")
print(f"After detach requires grad: {y.requires_grad}")

Practical Tips

1. Tensor Initialization Tips

python
import torch
import torch.nn as nn  # required: nn.init is used below but was never imported


# Xavier/Glorot uniform initialization (nn.init requires >= 2-D tensors)
def xavier_init(tensor):
    """Fill *tensor* in place with Xavier-uniform values."""
    nn.init.xavier_uniform_(tensor)


# He/Kaiming uniform initialization
def he_init(tensor):
    """Fill *tensor* in place with Kaiming-uniform values."""
    nn.init.kaiming_uniform_(tensor)


# Custom initialization: uniform in [-0.1, 0.1], outside autograd tracking
def custom_init(tensor):
    """Fill *tensor* in place with U(-0.1, 0.1) values under no_grad."""
    with torch.no_grad():
        tensor.uniform_(-0.1, 0.1)

2. Performance Optimization

python
# Use appropriate data types
x = torch.randn(1000, 1000, dtype=torch.float32)  # float32 halves memory vs float64

# Avoid frequent CPU-GPU transfers: pick the device once and keep all work there.
# Guard on availability (matches the Device Management section) so this also
# runs on CPU-only machines instead of raising a CUDA error.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(1000, 1000, device=device)
# Perform all computations on the chosen device
result = x.mm(x.t()).sum()

# Use in-place operations to save memory
x.add_(1)  # instead of x = x + 1

Summary

Tensors are the core of PyTorch, and mastering tensor operations is fundamental to deep learning:

  1. Creating Tensors: Understand various creation methods and data types
  2. Shape Operations: Proficiently use view, reshape, transpose, etc.
  3. Mathematical Operations: Master basic operations and matrix operations
  4. Broadcasting: Understand operation rules for tensors of different shapes
  5. Device Management: Reasonably use CPU and GPU resources

These basic operations will be frequently used in subsequent neural network construction, so practice is recommended!

Content is for learning and research only.