TensorFlow Data Processing
tf.data API Introduction
The tf.data API is a core tool in TensorFlow for building efficient data input pipelines. It provides a powerful set of tools to load, transform, and batch data, making it particularly suitable for processing large-scale datasets.
python
import tensorflow as tf
import numpy as np
import pandas as pd
from pathlib import Path
print(f"TensorFlow version: {tf.__version__}")

Creating Datasets
1. Create from Memory Data
python
# Create dataset from tensors
data = tf.constant([1, 2, 3, 4, 5])
dataset = tf.data.Dataset.from_tensor_slices(data)
print("Dataset created from tensors:")
for element in dataset:
print(element.numpy())
# Create from multiple tensors
features = tf.constant([[1, 2], [3, 4], [5, 6]])
labels = tf.constant([0, 1, 0])
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
print("\nFeature-label dataset:")
for feature, label in dataset:
print(f"Feature: {feature.numpy()}, Label: {label.numpy()}")
# Create from dictionaries
dataset_dict = tf.data.Dataset.from_tensor_slices({
'features': [[1, 2], [3, 4], [5, 6]],
'labels': [0, 1, 0]
})
print("\nDictionary format dataset:")
for element in dataset_dict:
print(f"Feature: {element['features'].numpy()}, Label: {element['labels'].numpy()}")2. Create from Generators
python
def data_generator():
"""Data generator function"""
for i in range(10):
yield i, i**2
# Create dataset from generator
dataset = tf.data.Dataset.from_generator(
data_generator,
output_signature=(
tf.TensorSpec(shape=(), dtype=tf.int32),
tf.TensorSpec(shape=(), dtype=tf.int32)
)
)
print("Dataset created from generator:")
for x, y in dataset.take(5):
print(f"x: {x.numpy()}, y: {y.numpy()}")
# More complex generator
def complex_generator():
"""Complex data generator"""
for i in range(100):
# Simulate complex data generation logic
features = np.random.randn(10).astype(np.float32)
label = np.random.randint(0, 3)
yield features, label
complex_dataset = tf.data.Dataset.from_generator(
complex_generator,
output_signature=(
tf.TensorSpec(shape=(10,), dtype=tf.float32),
tf.TensorSpec(shape=(), dtype=tf.int32)
)
)3. Create from Files
python
# Create from text files
def create_text_dataset():
    """Write a small sample file, then stream it back line by line."""
    sample_lines = ["line 1", "line 2", "line 3", "line 4", "line 5"]
    with open('sample.txt', 'w') as f:
        f.write('\n'.join(sample_lines) + '\n')

    # TextLineDataset yields one string tensor per line (without newlines).
    text_dataset = tf.data.TextLineDataset('sample.txt')
    print("Text file dataset:")
    for line in text_dataset:
        print(line.numpy().decode('utf-8'))
    return text_dataset

# create_text_dataset()
# Create from CSV files
def create_csv_dataset():
# Create sample CSV data
csv_data = pd.DataFrame({
'feature1': np.random.randn(100),
'feature2': np.random.randn(100),
'label': np.random.randint(0, 3, 100)
})
csv_data.to_csv('sample.csv', index=False)
# Create dataset from CSV
csv_dataset = tf.data.experimental.make_csv_dataset(
'sample.csv',
batch_size=5,
label_name='label',
num_epochs=1,
shuffle=False
)
print("CSV dataset:")
for batch in csv_dataset.take(2):
print("Features:", {k: v.numpy() for k, v in batch[0].items()})
print("Labels:", batch[1].numpy())
print()
return csv_dataset
# create_csv_dataset()4. Create from TFRecord Files
python
def create_tfrecord_dataset():
"""Create and read TFRecord dataset"""
# Create TFRecord file
def create_example(feature, label):
"""Create tf.train.Example"""
feature_dict = {
'feature': tf.train.Feature(
float_list=tf.train.FloatList(value=feature)
),
'label': tf.train.Feature(
int64_list=tf.train.Int64List(value=[label])
)
}
example = tf.train.Example(
features=tf.train.Features(feature=feature_dict)
)
return example.SerializeToString()
# Write to TFRecord file
with tf.io.TFRecordWriter('sample.tfrecord') as writer:
for i in range(100):
feature = np.random.randn(10).astype(np.float32)
label = np.random.randint(0, 3)
example = create_example(feature, label)
writer.write(example)
# Parse function
def parse_example(example_proto):
feature_description = {
'feature': tf.io.FixedLenFeature([10], tf.float32),
'label': tf.io.FixedLenFeature([], tf.int64)
}
return tf.io.parse_single_example(example_proto, feature_description)
# Create dataset from TFRecord
tfrecord_dataset = tf.data.TFRecordDataset('sample.tfrecord')
parsed_dataset = tfrecord_dataset.map(parse_example)
print("TFRecord dataset:")
for element in parsed_dataset.take(3):
print(f"Feature shape: {element['feature'].shape}, Label: {element['label'].numpy()}")
return parsed_dataset
# create_tfrecord_dataset()Dataset Transformations
1. Basic Transformations
python
# Create base dataset
dataset = tf.data.Dataset.range(10)
# map transformation: apply function to each element
squared_dataset = dataset.map(lambda x: x ** 2)
print("Square transformation:")
for element in squared_dataset:
print(element.numpy())
# filter transformation: filter elements
even_dataset = dataset.filter(lambda x: x % 2 == 0)
print("\nEven number filter:")
for element in even_dataset:
print(element.numpy())
# take transformation: take first n elements
first_five = dataset.take(5)
print("\nFirst 5 elements:")
for element in first_five:
print(element.numpy())
# skip transformation: skip first n elements
skip_five = dataset.skip(5)
print("\nSkip first 5 elements:")
for element in skip_five:
print(element.numpy())2. Batching and Repeating
python
# Create sample dataset
features = tf.random.normal([100, 10])
labels = tf.random.uniform([100], maxval=3, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
# Batching
batched_dataset = dataset.batch(32)
print("Shape after batching:")
for batch_features, batch_labels in batched_dataset.take(1):
print(f"Feature batch shape: {batch_features.shape}")
print(f"Label batch shape: {batch_labels.shape}")
# Repeat dataset
repeated_dataset = dataset.repeat(3) # Repeat 3 times
print(f"\nDataset size after repeating: {len(list(repeated_dataset))}")
# Infinite repeat
infinite_dataset = dataset.repeat() # Infinite repeat
# Shuffle data
shuffled_dataset = dataset.shuffle(buffer_size=100)
print("\nShuffled dataset:")
for features, labels in shuffled_dataset.take(3):
print(f"Labels: {labels.numpy()}")3. Complex Transformations
python
# Create sequence dataset
sequence_dataset = tf.data.Dataset.range(20)
# window transformation: create sliding windows
windowed_dataset = sequence_dataset.window(size=5, shift=2, drop_remainder=True)
windowed_dataset = windowed_dataset.flat_map(lambda window: window.batch(5))
print("Sliding window:")
for window in windowed_dataset.take(3):
print(window.numpy())
# flat_map transformation: flatten nested structures
nested_dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4], [5, 6]])
flattened_dataset = nested_dataset.flat_map(tf.data.Dataset.from_tensor_slices)
print("\nFlatten nested structure:")
for element in flattened_dataset:
print(element.numpy())
# zip transformation: combine multiple datasets
dataset1 = tf.data.Dataset.range(5)
dataset2 = tf.data.Dataset.range(5, 10)
zipped_dataset = tf.data.Dataset.zip((dataset1, dataset2))
print("\nCombined dataset:")
for x, y in zipped_dataset:
print(f"({x.numpy()}, {y.numpy()})")4. Data Preprocessing
python
def preprocess_image_data():
    """Image data preprocessing example (paths are simulated)."""
    # Simulated image path dataset.
    image_paths = ['image1.jpg', 'image2.jpg', 'image3.jpg']
    labels = [0, 1, 2]
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels))

    def load_and_preprocess_image(path, label):
        # In a real application this would load the file from disk:
        #   image = tf.io.read_file(path)
        #   image = tf.image.decode_image(image, channels=3)
        # Here a random tensor stands in for the decoded image.
        image = tf.random.normal([224, 224, 3])
        image = tf.cast(image, tf.float32)
        image = tf.image.resize(image, [224, 224])
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, 0.2)
        image = tf.image.per_image_standardization(image)
        return image, label

    # Run preprocessing with runtime-tuned parallelism.
    return dataset.map(load_and_preprocess_image,
                       num_parallel_calls=tf.data.AUTOTUNE)
def preprocess_text_data():
"""Text data preprocessing example"""
texts = ["hello world", "tensorflow is great", "deep learning rocks"]
labels = [0, 1, 1]
dataset = tf.data.Dataset.from_tensor_slices((texts, labels))
# Create vocabulary
vocab = ["hello", "world", "tensorflow", "is", "great", "deep", "learning", "rocks"]
vocab_table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(
keys=vocab,
values=tf.range(len(vocab), dtype=tf.int64)
),
default_value=-1
)
def preprocess_text(text, label):
# Tokenization (simplified version)
words = tf.strings.split(text)
# Vocabulary mapping
word_ids = vocab_table.lookup(words)
# Pad to fixed length
word_ids = tf.pad(word_ids, [[0, 10 - tf.shape(word_ids)[0]]])[:10]
return word_ids, label
processed_dataset = dataset.map(preprocess_text)
return processed_dataset
# Test preprocessing
# image_dataset = preprocess_image_data()
# text_dataset = preprocess_text_data()Performance Optimization
1. Parallel Processing
python
# Create large dataset for performance testing
large_dataset = tf.data.Dataset.range(10000)
def slow_function(x):
"""Simulate time-consuming operation"""
tf.py_function(lambda: tf.numpy_function(lambda x: x**2, [x], tf.int64), [], tf.int64)
return x ** 2
# Serial processing
serial_dataset = large_dataset.map(slow_function)
# Parallel processing
parallel_dataset = large_dataset.map(
slow_function,
num_parallel_calls=tf.data.AUTOTUNE # Auto-adjust parallelism
)
# Manually specify parallelism
manual_parallel_dataset = large_dataset.map(
slow_function,
num_parallel_calls=4 # Use 4 parallel calls
)
print("Parallel processing configured")2. Prefetching and Caching
python
def optimize_dataset_performance(dataset):
"""Dataset performance optimization"""
# Cache dataset (suitable for small datasets)
cached_dataset = dataset.cache()
# Prefetch data
prefetched_dataset = dataset.prefetch(tf.data.AUTOTUNE)
# Combined optimization
optimized_dataset = (dataset
.cache() # Cache
.shuffle(1000) # Shuffle
.batch(32) # Batch
.prefetch(tf.data.AUTOTUNE) # Prefetch
)
return optimized_dataset
# Performance comparison test
def benchmark_dataset(dataset, num_epochs=3):
"""Dataset performance benchmark"""
import time
start_time = time.time()
for epoch in range(num_epochs):
for batch in dataset:
# Simulate training step
pass
end_time = time.time()
return end_time - start_time
# Create test dataset
test_dataset = tf.data.Dataset.range(1000).map(lambda x: x**2)
# Base dataset
basic_time = benchmark_dataset(test_dataset.batch(32))
# Optimized dataset
optimized_dataset = optimize_dataset_performance(test_dataset)
optimized_time = benchmark_dataset(optimized_dataset)
print(f"Base dataset time: {basic_time:.2f} seconds")
print(f"Optimized dataset time: {optimized_time:.2f} seconds")
print(f"Performance improvement: {basic_time/optimized_time:.2f}x")3. Memory Optimization
python
def memory_efficient_dataset():
"""Memory efficient dataset processing"""
# Use generator to avoid loading all data at once
def data_generator():
for i in range(100000): # Large dataset
yield np.random.randn(100).astype(np.float32), i % 10
dataset = tf.data.Dataset.from_generator(
data_generator,
output_signature=(
tf.TensorSpec(shape=(100,), dtype=tf.float32),
tf.TensorSpec(shape=(), dtype=tf.int32)
)
)
# Stream processing to avoid memory overflow
processed_dataset = (dataset
.map(lambda x, y: (tf.nn.l2_normalize(x), y))
.batch(32)
.prefetch(2) # Only prefetch 2 batches
)
return processed_dataset
# memory_efficient_dataset()Data Augmentation
1. Image Data Augmentation
python
def image_augmentation_pipeline():
"""Image data augmentation pipeline"""
def augment_image(image, label):
# Random flip
image = tf.image.random_flip_left_right(image)
image = tf.image.random_flip_up_down(image)
# Random rotation
image = tf.image.rot90(image, k=tf.random.uniform([], 0, 4, dtype=tf.int32))
# Color adjustment
image = tf.image.random_brightness(image, 0.2)
image = tf.image.random_contrast(image, 0.8, 1.2)
image = tf.image.random_saturation(image, 0.8, 1.2)
image = tf.image.random_hue(image, 0.1)
# Random crop and resize
image = tf.image.random_crop(image, [200, 200, 3])
image = tf.image.resize(image, [224, 224])
# Normalization
image = tf.cast(image, tf.float32) / 255.0
image = tf.image.per_image_standardization(image)
return image, label
# Create simulated image dataset
images = tf.random.normal([100, 224, 224, 3])
labels = tf.random.uniform([100], maxval=10, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
# Apply data augmentation
augmented_dataset = dataset.map(
augment_image,
num_parallel_calls=tf.data.AUTOTUNE
)
return augmented_dataset
# augmented_dataset = image_augmentation_pipeline()2. Text Data Augmentation
python
def text_augmentation_pipeline():
"""Text data augmentation pipeline"""
def augment_text(text, label):
# Random word replacement (simplified version)
words = tf.strings.split(text)
# Random word deletion
num_words = tf.shape(words)[0]
keep_prob = 0.9
mask = tf.random.uniform([num_words]) < keep_prob
filtered_words = tf.boolean_mask(words, mask)
# Rejoin
augmented_text = tf.strings.join(filtered_words, separator=' ')
return augmented_text, label
# Sample text data
texts = tf.constant([
"this is a great movie",
"terrible film not recommended",
"amazing story and acting"
])
labels = tf.constant([1, 0, 1])
dataset = tf.data.Dataset.from_tensor_slices((texts, labels))
augmented_dataset = dataset.map(augment_text)
return augmented_dataset
# text_augmented_dataset = text_augmentation_pipeline()Data Validation and Debugging
1. Dataset Inspection
python
def inspect_dataset(dataset, name="Dataset"):
"""Check dataset content and structure"""
print(f"\n=== {name} Inspection ===")
# Check dataset structure
print(f"Element specification: {dataset.element_spec}")
# View first few samples
print("First 3 samples:")
for i, element in enumerate(dataset.take(3)):
if isinstance(element, tuple):
print(f"Sample {i}: Feature shape={element[0].shape}, Label={element[1].numpy()}")
else:
print(f"Sample {i}: {element.numpy()}")
# Statistics
try:
cardinality = tf.data.experimental.cardinality(dataset)
if cardinality != tf.data.experimental.UNKNOWN_CARDINALITY:
print(f"Dataset size: {cardinality.numpy()}")
else:
print("Dataset size: Unknown")
except:
print("Unable to get dataset size")
# Create test dataset
test_features = tf.random.normal([50, 10])
test_labels = tf.random.uniform([50], maxval=3, dtype=tf.int32)
test_dataset = tf.data.Dataset.from_tensor_slices((test_features, test_labels))
inspect_dataset(test_dataset, "Test Dataset")
inspect_dataset(test_dataset.batch(8), "Batched Dataset")2. Data Quality Check
python
def validate_data_quality(dataset):
"""Data quality validation"""
def check_data_quality(features, labels):
# Check for NaN values
has_nan = tf.reduce_any(tf.math.is_nan(features))
# Check for infinite values
has_inf = tf.reduce_any(tf.math.is_inf(features))
# Check label range
valid_labels = tf.logical_and(labels >= 0, labels < 10)
# Log issues
tf.cond(
has_nan,
lambda: tf.print("Warning: NaN values detected"),
lambda: tf.no_op()
)
tf.cond(
has_inf,
lambda: tf.print("Warning: Infinite values detected"),
lambda: tf.no_op()
)
tf.cond(
tf.reduce_any(tf.logical_not(valid_labels)),
lambda: tf.print("Warning: Invalid labels detected"),
lambda: tf.no_op()
)
return features, labels
validated_dataset = dataset.map(check_data_quality)
return validated_dataset
# Create dataset with issues for testing
problematic_features = tf.constant([[1.0, float('nan')], [float('inf'), 3.0]])
problematic_labels = tf.constant([1, 15]) # Labels out of range
problematic_dataset = tf.data.Dataset.from_tensor_slices((problematic_features, problematic_labels))
# validated_dataset = validate_data_quality(problematic_dataset)Practical Application Examples
1. Complete Image Classification Data Pipeline
python
def create_image_classification_pipeline(image_dir, batch_size=32, image_size=(224, 224)):
"""Create complete image classification data pipeline"""
# Get image paths and labels
def get_image_paths_and_labels(image_dir):
# This should implement actual file scanning logic
# Return list of image paths and corresponding labels
paths = ['img1.jpg', 'img2.jpg', 'img3.jpg'] # Example
labels = [0, 1, 2] # Example
return paths, labels
def load_and_preprocess_image(path, label):
# Load image
image = tf.io.read_file(path)
image = tf.image.decode_image(image, channels=3)
image = tf.cast(image, tf.float32)
# Preprocess
image = tf.image.resize(image, image_size)
image = tf.image.per_image_standardization(image)
return image, label
def augment_for_training(image, label):
# Data augmentation during training
image = tf.image.random_flip_left_right(image)
image = tf.image.random_brightness(image, 0.1)
image = tf.image.random_contrast(image, 0.9, 1.1)
return image, label
# Get data
paths, labels = get_image_paths_and_labels(image_dir)
# Create dataset
dataset = tf.data.Dataset.from_tensor_slices((paths, labels))
# Build pipeline
dataset = (dataset
.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
.cache() # Cache preprocessing results
.map(augment_for_training, num_parallel_calls=tf.data.AUTOTUNE)
.shuffle(1000)
.batch(batch_size)
.prefetch(tf.data.AUTOTUNE)
)
return dataset
# Usage example
# train_dataset = create_image_classification_pipeline('train_images/')
# val_dataset = create_image_classification_pipeline('val_images/')2. Text Classification Data Pipeline
python
def create_text_classification_pipeline(texts, labels, vocab_size=10000, max_length=100):
    """Create text classification data pipeline.

    Fits a Keras tokenizer on *texts*, then maps each text to a
    fixed-length int32 id sequence. The tokenizer is plain Python, so the
    per-element work runs through tf.py_function.

    Returns:
        (dataset, tokenizer) tuple.
    """
    tokenizer = tf.keras.preprocessing.text.Tokenizer(
        num_words=vocab_size,
        oov_token='<OOV>'
    )
    tokenizer.fit_on_texts(texts)

    def preprocess_text(text, label):
        # Tokenization and encoding (eager-only: uses .numpy()).
        sequences = tokenizer.texts_to_sequences([text.numpy().decode('utf-8')])
        sequence = tf.constant(sequences[0], dtype=tf.int32)
        # BUG FIX: truncate BEFORE padding; for a text longer than
        # max_length the original pad amount was negative, which raises.
        sequence = sequence[:max_length]
        sequence = tf.pad(sequence, [[0, max_length - tf.shape(sequence)[0]]])
        return sequence, label

    dataset = tf.data.Dataset.from_tensor_slices((texts, labels))

    def tf_preprocess(text, label):
        sequence, label = tf.py_function(
            preprocess_text, [text, label], [tf.int32, tf.int32])
        # BUG FIX: py_function erases static shape information; restore it
        # so downstream batching/models see a fixed-length sequence.
        sequence.set_shape([max_length])
        label.set_shape([])
        return sequence, label

    dataset = dataset.map(tf_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset, tokenizer

# Example usage
sample_texts = ["这是一个好电影", "糟糕的体验", "非常推荐"]
sample_labels = [1, 0, 1]
# text_dataset, text_tokenizer = create_text_classification_pipeline(sample_texts, sample_labels)

Summary
The tf.data API is a powerful tool in TensorFlow for data processing, with main features including:
- Multiple Data Sources: Supports memory, files, generators, and various data sources
- Rich Transformation Operations: map, filter, batch, shuffle, etc.
- Performance Optimization: Parallel processing, prefetching, caching, and other optimization techniques
- Data Augmentation: Built-in image and text augmentation capabilities
- Easy Debugging: Provides data inspection and validation tools
Mastering the tf.data API will greatly improve your data processing efficiency and model training performance!