Skip to content

TensorFlow Data Processing

tf.data API Introduction

The tf.data API is a core tool in TensorFlow for building efficient data input pipelines. It provides a powerful set of tools to load, transform, and batch data, making it particularly suitable for processing large-scale datasets.

python
import tensorflow as tf
import numpy as np
import pandas as pd
from pathlib import Path

# Confirm the installed version — all examples below assume TF 2.x eager mode
print(f"TensorFlow version: {tf.__version__}")

Creating Datasets

1. Create from Memory Data

python
# Build a dataset by slicing a 1-D tensor into scalar elements
values = tf.constant([1, 2, 3, 4, 5])
dataset = tf.data.Dataset.from_tensor_slices(values)

print("Dataset created from tensors:")
for item in dataset:
    print(item.numpy())

# Slicing a tuple of tensors yields (feature, label) tuple elements
feature_matrix = tf.constant([[1, 2], [3, 4], [5, 6]])
label_vector = tf.constant([0, 1, 0])
dataset = tf.data.Dataset.from_tensor_slices((feature_matrix, label_vector))

print("\nFeature-label dataset:")
for feature, label in dataset:
    print(f"Feature: {feature.numpy()}, Label: {label.numpy()}")

# Slicing a dict of equal-length columns yields dict-structured elements
dataset_dict = tf.data.Dataset.from_tensor_slices(
    {'features': [[1, 2], [3, 4], [5, 6]], 'labels': [0, 1, 0]}
)

print("\nDictionary format dataset:")
for element in dataset_dict:
    print(f"Feature: {element['features'].numpy()}, Label: {element['labels'].numpy()}")

2. Create from Generators

python
def data_generator():
    """Yield (i, i squared) pairs for i = 0..9."""
    yield from ((i, i * i) for i in range(10))

# Wrap the Python generator in a tf.data pipeline; output_signature
# declares the dtype/shape of each component the generator yields.
scalar_int = tf.TensorSpec(shape=(), dtype=tf.int32)
dataset = tf.data.Dataset.from_generator(
    data_generator,
    output_signature=(scalar_int, scalar_int),
)

print("Dataset created from generator:")
for x, y in dataset.take(5):
    print(f"x: {x.numpy()}, y: {y.numpy()}")

# A generator producing richer, randomly generated samples
def complex_generator():
    """Yield 100 random samples: (10-dim float32 vector, int label in [0, 3))."""
    for _ in range(100):
        yield np.random.randn(10).astype(np.float32), np.random.randint(0, 3)

# Same pattern for structured output: declare a spec per yielded component
complex_dataset = tf.data.Dataset.from_generator(
    complex_generator,
    output_signature=(
        tf.TensorSpec(shape=(10,), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32),
    ),
)

3. Create from Files

python
# Create from text files
def create_text_dataset():
    """Write a small sample file, then stream it back line by line."""
    # Materialize sample lines on disk so TextLineDataset has a real file
    with open('sample.txt', 'w') as f:
        f.write('\n'.join(f"line {n}" for n in range(1, 6)) + '\n')

    # Each element of the dataset is one line of the file (a bytes tensor)
    text_dataset = tf.data.TextLineDataset('sample.txt')

    print("Text file dataset:")
    for line in text_dataset:
        print(line.numpy().decode('utf-8'))

    return text_dataset

# create_text_dataset()

# Create from CSV files
def create_csv_dataset():
    """Generate a random CSV on disk and read it back as a batched dataset."""
    # Random two-feature table with integer labels in [0, 3)
    frame = pd.DataFrame({
        'feature1': np.random.randn(100),
        'feature2': np.random.randn(100),
        'label': np.random.randint(0, 3, 100),
    })
    frame.to_csv('sample.csv', index=False)

    # make_csv_dataset parses rows into (features-dict, label) batches
    csv_dataset = tf.data.experimental.make_csv_dataset(
        'sample.csv',
        batch_size=5,
        label_name='label',
        num_epochs=1,
        shuffle=False,
    )

    print("CSV dataset:")
    for features, labels in csv_dataset.take(2):
        print("Features:", {name: column.numpy() for name, column in features.items()})
        print("Labels:", labels.numpy())
        print()

    return csv_dataset

# create_csv_dataset()

4. Create from TFRecord Files

python
def create_tfrecord_dataset():
    """Create and read TFRecord dataset.

    Writes 100 random examples to 'sample.tfrecord', then reads them
    back, parsing each serialized Example with a matching schema.
    """

    # Create TFRecord file
    def create_example(feature, label):
        """Serialize one (feature vector, label) pair as a tf.train.Example."""
        # Protobuf feature map: a float list for the vector, an int64 list
        # (of length 1) for the label — the only types TFRecord supports
        feature_dict = {
            'feature': tf.train.Feature(
                float_list=tf.train.FloatList(value=feature)
            ),
            'label': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[label])
            )
        }
        example = tf.train.Example(
            features=tf.train.Features(feature=feature_dict)
        )
        # The serialized bytes are what actually gets written to the file
        return example.SerializeToString()

    # Write 100 random examples to the TFRecord file
    with tf.io.TFRecordWriter('sample.tfrecord') as writer:
        for i in range(100):
            feature = np.random.randn(10).astype(np.float32)
            label = np.random.randint(0, 3)
            example = create_example(feature, label)
            writer.write(example)

    # Parse function: schema must mirror create_example exactly
    def parse_example(example_proto):
        feature_description = {
            'feature': tf.io.FixedLenFeature([10], tf.float32),
            'label': tf.io.FixedLenFeature([], tf.int64)
        }
        return tf.io.parse_single_example(example_proto, feature_description)

    # Create dataset from TFRecord and decode each record lazily via map
    tfrecord_dataset = tf.data.TFRecordDataset('sample.tfrecord')
    parsed_dataset = tfrecord_dataset.map(parse_example)

    print("TFRecord dataset:")
    for element in parsed_dataset.take(3):
        print(f"Feature shape: {element['feature'].shape}, Label: {element['label'].numpy()}")

    return parsed_dataset

# create_tfrecord_dataset()

Dataset Transformations

1. Basic Transformations

python
# Base dataset: the integers 0..9
dataset = tf.data.Dataset.range(10)

# map: apply a function to every element
squared_dataset = dataset.map(lambda v: v ** 2)
print("Square transformation:")
for item in squared_dataset:
    print(item.numpy())

# filter: keep only elements satisfying a predicate
even_dataset = dataset.filter(lambda v: v % 2 == 0)
print("\nEven number filter:")
for item in even_dataset:
    print(item.numpy())

# take: truncate to the first n elements
first_five = dataset.take(5)
print("\nFirst 5 elements:")
for item in first_five:
    print(item.numpy())

# skip: drop the first n elements
skip_five = dataset.skip(5)
print("\nSkip first 5 elements:")
for item in skip_five:
    print(item.numpy())

2. Batching and Repeating

python
# Sample dataset: 100 rows of 10 random features, labels in [0, 3)
features = tf.random.normal([100, 10])
labels = tf.random.uniform([100], maxval=3, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

# batch groups consecutive elements along a new leading axis
batched_dataset = dataset.batch(32)
print("Shape after batching:")
for batch_features, batch_labels in batched_dataset.take(1):
    print(f"Feature batch shape: {batch_features.shape}")
    print(f"Label batch shape: {batch_labels.shape}")

# repeat(3) concatenates three passes over the data
repeated_dataset = dataset.repeat(3)
print(f"\nDataset size after repeating: {len(list(repeated_dataset))}")

# repeat() without a count loops forever — bound it with take() or
# steps_per_epoch when training
infinite_dataset = dataset.repeat()

# shuffle draws each element randomly from a buffer of the given size
shuffled_dataset = dataset.shuffle(buffer_size=100)
print("\nShuffled dataset:")
for features, labels in shuffled_dataset.take(3):
    print(f"Labels: {labels.numpy()}")

3. Complex Transformations

python
# Sequence 0..19 for the windowing demo
sequence_dataset = tf.data.Dataset.range(20)

# window yields sub-datasets; flat_map + batch collapses each window
# into a single tensor of 5 elements
windowed_dataset = (
    sequence_dataset
    .window(size=5, shift=2, drop_remainder=True)
    .flat_map(lambda window: window.batch(5))
)

print("Sliding window:")
for window in windowed_dataset.take(3):
    print(window.numpy())

# flat_map flattens a dataset of rows into a dataset of scalars
nested_dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4], [5, 6]])
flattened_dataset = nested_dataset.flat_map(tf.data.Dataset.from_tensor_slices)

print("\nFlatten nested structure:")
for element in flattened_dataset:
    print(element.numpy())

# zip pairs up elements from multiple datasets, like Python's zip()
dataset1 = tf.data.Dataset.range(5)
dataset2 = tf.data.Dataset.range(5, 10)
zipped_dataset = tf.data.Dataset.zip((dataset1, dataset2))

print("\nCombined dataset:")
for x, y in zipped_dataset:
    print(f"({x.numpy()}, {y.numpy()})")

4. Data Preprocessing

python
def preprocess_image_data():
    """Image data preprocessing example (uses fabricated images)."""
    # Stand-in for a real (path, label) listing
    image_paths = ['image1.jpg', 'image2.jpg', 'image3.jpg']
    labels = [0, 1, 2]

    dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels))

    def load_and_preprocess_image(path, label):
        # Real code would decode the file:
        #   image = tf.io.read_file(path)
        #   image = tf.image.decode_image(image, channels=3)
        # Here a random tensor stands in for the decoded image.
        image = tf.random.normal([224, 224, 3])

        # Cast, resize, light augmentation, then per-image
        # standardization (zero mean, unit variance)
        image = tf.cast(image, tf.float32)
        image = tf.image.resize(image, [224, 224])
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, 0.2)
        image = tf.image.per_image_standardization(image)

        return image, label

    # AUTOTUNE lets tf.data choose the parallelism level
    return dataset.map(
        load_and_preprocess_image,
        num_parallel_calls=tf.data.AUTOTUNE,
    )

def preprocess_text_data():
    """Text data preprocessing example.

    Builds a tiny static vocabulary, maps each word to its id
    (-1 for out-of-vocabulary), and pads/truncates every sequence
    to a fixed length of 10.

    Returns:
        A tf.data.Dataset of (word_ids, label) pairs, word_ids shape (10,).
    """
    texts = ["hello world", "tensorflow is great", "deep learning rocks"]
    labels = [0, 1, 1]

    dataset = tf.data.Dataset.from_tensor_slices((texts, labels))

    # Static word -> id lookup table; unknown words map to -1
    vocab = ["hello", "world", "tensorflow", "is", "great", "deep", "learning", "rocks"]
    vocab_table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(
            keys=vocab,
            values=tf.range(len(vocab), dtype=tf.int64)
        ),
        default_value=-1
    )

    max_len = 10  # fixed output sequence length

    def preprocess_text(text, label):
        # Whitespace tokenization (simplified version)
        words = tf.strings.split(text)

        # Vocabulary mapping
        word_ids = vocab_table.lookup(words)

        # Truncate first, THEN pad. The original padded with
        # `10 - len(word_ids)`, which is negative for texts longer than
        # 10 words and makes tf.pad raise (paddings must be non-negative).
        word_ids = word_ids[:max_len]
        pad_amount = max_len - tf.shape(word_ids)[0]
        word_ids = tf.pad(word_ids, [[0, pad_amount]])

        return word_ids, label

    return dataset.map(preprocess_text)

# Test preprocessing
# image_dataset = preprocess_image_data()
# text_dataset = preprocess_text_data()

Performance Optimization

1. Parallel Processing

python
# Large dataset (10,000 integer elements) used to compare map() strategies below
large_dataset = tf.data.Dataset.range(10000)

def slow_function(x):
    """Simulate a per-element transformation for the map() comparisons below.

    The previous body also built a malformed tf.py_function wrapper (a
    zero-argument lambda with an empty input list whose result was
    discarded); it contributed nothing to the returned value and has
    been removed as dead code.

    Args:
        x: A numeric scalar (or scalar tensor when used inside dataset.map).

    Returns:
        x squared.
    """
    return x ** 2

# Serial mapping: elements are transformed one at a time
serial_dataset = large_dataset.map(slow_function)

# Parallel mapping: AUTOTUNE lets the runtime pick the thread count
parallel_dataset = large_dataset.map(
    slow_function,
    num_parallel_calls=tf.data.AUTOTUNE,
)

# Or pin the parallelism level explicitly
manual_parallel_dataset = large_dataset.map(
    slow_function,
    num_parallel_calls=4,
)

print("Parallel processing configured")

2. Prefetching and Caching

python
def optimize_dataset_performance(dataset):
    """Apply the standard tf.data performance recipe to *dataset*.

    Order matters: cache() before shuffle/batch stores raw elements after
    the first epoch; prefetch() last overlaps producing the next batch
    with consuming the current one.

    Note: the original also built standalone cached/prefetched variants
    (`cached_dataset`, `prefetched_dataset`) that were never used; those
    dead locals have been removed.

    Args:
        dataset: An un-batched tf.data.Dataset.

    Returns:
        The cached, shuffled, batched, prefetching dataset.
    """
    return (dataset
            .cache()                      # keep elements in memory after epoch 1
            .shuffle(1000)                # sample from a 1000-element buffer
            .batch(32)                    # group into batches of 32
            .prefetch(tf.data.AUTOTUNE))  # overlap preprocessing with training

# Performance comparison test
def benchmark_dataset(dataset, num_epochs=3):
    """Time how long it takes to iterate *dataset* num_epochs times.

    Args:
        dataset: Any iterable (typically a tf.data.Dataset).
        num_epochs: Number of full passes to make.

    Returns:
        Elapsed wall-clock seconds as a float.
    """
    import time

    begin = time.time()
    for _ in range(num_epochs):
        for _batch in dataset:
            pass  # stand-in for a training step
    return time.time() - begin

# Create test dataset: the squares of 0..999
test_dataset = tf.data.Dataset.range(1000).map(lambda x: x**2)

# Baseline: batching only
basic_time = benchmark_dataset(test_dataset.batch(32))

# With the full cache/shuffle/batch/prefetch recipe applied
optimized_dataset = optimize_dataset_performance(test_dataset)
optimized_time = benchmark_dataset(optimized_dataset)

print(f"Base dataset time: {basic_time:.2f} seconds")
print(f"Optimized dataset time: {optimized_time:.2f} seconds")
print(f"Performance improvement: {basic_time/optimized_time:.2f}x")

3. Memory Optimization

python
def memory_efficient_dataset():
    """Memory efficient dataset processing via streaming."""

    # The generator yields one sample at a time, so the full 100k-row
    # dataset never has to exist in memory at once
    def data_generator():
        for i in range(100000):
            yield np.random.randn(100).astype(np.float32), i % 10

    dataset = tf.data.Dataset.from_generator(
        data_generator,
        output_signature=(
            tf.TensorSpec(shape=(100,), dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int32),
        ),
    )

    # Stream: normalize, batch, and keep only a small prefetch buffer
    return (dataset
            .map(lambda x, y: (tf.nn.l2_normalize(x), y))
            .batch(32)
            .prefetch(2))  # only 2 batches buffered ahead

# memory_efficient_dataset()

Data Augmentation

1. Image Data Augmentation

python
def image_augmentation_pipeline():
    """Image data augmentation pipeline (on simulated images)."""

    def augment_image(image, label):
        # Geometric: random flips plus a random multiple-of-90° rotation
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        image = tf.image.rot90(image, k=tf.random.uniform([], 0, 4, dtype=tf.int32))

        # Photometric: jitter brightness, contrast, saturation, and hue
        image = tf.image.random_brightness(image, 0.2)
        image = tf.image.random_contrast(image, 0.8, 1.2)
        image = tf.image.random_saturation(image, 0.8, 1.2)
        image = tf.image.random_hue(image, 0.1)

        # Spatial: crop a random 200x200 patch, scale back to 224x224
        image = tf.image.random_crop(image, [200, 200, 3])
        image = tf.image.resize(image, [224, 224])

        # Scale then standardize per image
        image = tf.cast(image, tf.float32) / 255.0
        image = tf.image.per_image_standardization(image)

        return image, label

    # Simulated inputs: 100 random 224x224 RGB images with labels in [0, 10)
    images = tf.random.normal([100, 224, 224, 3])
    labels = tf.random.uniform([100], maxval=10, dtype=tf.int32)
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))

    # Apply the augmentation with auto-tuned parallelism
    return dataset.map(augment_image, num_parallel_calls=tf.data.AUTOTUNE)

# augmented_dataset = image_augmentation_pipeline()

2. Text Data Augmentation

python
def text_augmentation_pipeline():
    """Text data augmentation pipeline.

    Augments each sentence by randomly dropping ~10% of its words and
    rejoining the survivors into a single string.

    Returns:
        A tf.data.Dataset of (augmented_text, label) pairs.
    """

    def augment_text(text, label):
        # Whitespace tokenization
        words = tf.strings.split(text)

        # Keep each word independently with probability 0.9
        num_words = tf.shape(words)[0]
        keep_prob = 0.9
        mask = tf.random.uniform([num_words]) < keep_prob
        filtered_words = tf.boolean_mask(words, mask)

        # Rejoin into ONE scalar string. tf.strings.reduce_join collapses
        # the word axis; the original tf.strings.join expects a *list* of
        # string tensors to join element-wise and does not reduce a single
        # 1-D tensor to a scalar.
        augmented_text = tf.strings.reduce_join(filtered_words, separator=' ')

        return augmented_text, label

    # Sample text data
    texts = tf.constant([
        "this is a great movie",
        "terrible film not recommended",
        "amazing story and acting"
    ])
    labels = tf.constant([1, 0, 1])

    dataset = tf.data.Dataset.from_tensor_slices((texts, labels))
    return dataset.map(augment_text)

# text_augmented_dataset = text_augmentation_pipeline()

Data Validation and Debugging

1. Dataset Inspection

python
def inspect_dataset(dataset, name="Dataset"):
    """Print a dataset's structure, first samples, and (if known) its size.

    Args:
        dataset: The tf.data.Dataset to inspect.
        name: Heading used in the printed report.
    """
    print(f"\n=== {name} Inspection ===")

    # Structure: dtype/shape spec of one element
    print(f"Element specification: {dataset.element_spec}")

    # Peek at the first three elements
    print("First 3 samples:")
    for i, element in enumerate(dataset.take(3)):
        if isinstance(element, tuple):
            print(f"Sample {i}: Feature shape={element[0].shape}, Label={element[1].numpy()}")
        else:
            print(f"Sample {i}: {element.numpy()}")

    # Size, when tf.data can determine it statically
    try:
        cardinality = tf.data.experimental.cardinality(dataset)
        if cardinality != tf.data.experimental.UNKNOWN_CARDINALITY:
            print(f"Dataset size: {cardinality.numpy()}")
        else:
            print("Dataset size: Unknown")
    except Exception:
        # Narrowed from a bare `except:`, which would also swallow
        # SystemExit and KeyboardInterrupt
        print("Unable to get dataset size")

# Create test dataset: 50 rows of 10 random features, labels in [0, 3)
test_features = tf.random.normal([50, 10])
test_labels = tf.random.uniform([50], maxval=3, dtype=tf.int32)
test_dataset = tf.data.Dataset.from_tensor_slices((test_features, test_labels))

# Inspect both the per-example view and the batched view
inspect_dataset(test_dataset, "Test Dataset")
inspect_dataset(test_dataset.batch(8), "Batched Dataset")

2. Data Quality Check

python
def validate_data_quality(dataset):
    """Data quality validation.

    Wraps *dataset* with a pass-through map that prints warnings when it
    sees NaN/inf features or out-of-range labels. Elements flow through
    unchanged.
    """

    def check_data_quality(features, labels):
        # Check for NaN values anywhere in the feature tensor
        has_nan = tf.reduce_any(tf.math.is_nan(features))

        # Check for infinite values
        has_inf = tf.reduce_any(tf.math.is_inf(features))

        # Check label range — assumes valid labels lie in [0, 10);
        # NOTE(review): confirm against the actual number of classes
        valid_labels = tf.logical_and(labels >= 0, labels < 10)

        # Log issues. tf.cond (rather than Python `if`) is used so the
        # branching also works when this map fn is traced into a graph.
        tf.cond(
            has_nan,
            lambda: tf.print("Warning: NaN values detected"),
            lambda: tf.no_op()
        )

        tf.cond(
            has_inf,
            lambda: tf.print("Warning: Infinite values detected"),
            lambda: tf.no_op()
        )

        tf.cond(
            tf.reduce_any(tf.logical_not(valid_labels)),
            lambda: tf.print("Warning: Invalid labels detected"),
            lambda: tf.no_op()
        )

        # Pass data through unchanged; this map only logs
        return features, labels

    validated_dataset = dataset.map(check_data_quality)
    return validated_dataset

# Create dataset with deliberate issues for testing:
# a NaN feature, an inf feature, and a label (15) outside [0, 10)
problematic_features = tf.constant([[1.0, float('nan')], [float('inf'), 3.0]])
problematic_labels = tf.constant([1, 15])  # Labels out of range
problematic_dataset = tf.data.Dataset.from_tensor_slices((problematic_features, problematic_labels))

# validated_dataset = validate_data_quality(problematic_dataset)

Practical Application Examples

1. Complete Image Classification Data Pipeline

python
def create_image_classification_pipeline(image_dir, batch_size=32, image_size=(224, 224)):
    """Create a complete image-classification input pipeline.

    Args:
        image_dir: Directory to scan for images (placeholder logic below).
        batch_size: Number of examples per batch.
        image_size: (height, width) every image is resized to.

    Returns:
        A shuffled, batched, prefetching tf.data.Dataset of (image, label).
    """

    def get_image_paths_and_labels(image_dir):
        # Placeholder — real code would scan image_dir for files and labels
        paths = ['img1.jpg', 'img2.jpg', 'img3.jpg']  # Example
        labels = [0, 1, 2]  # Example
        return paths, labels

    def load_and_preprocess_image(path, label):
        # Load and decode. expand_animations=False guarantees a rank-3
        # tensor; without it, decode_image may return a 4-D animation
        # frame stack with no static shape, and tf.image.resize fails
        # inside dataset.map.
        image = tf.io.read_file(path)
        image = tf.image.decode_image(image, channels=3, expand_animations=False)
        image = tf.cast(image, tf.float32)

        # Resize to the target size, then standardize per image
        image = tf.image.resize(image, image_size)
        image = tf.image.per_image_standardization(image)

        return image, label

    def augment_for_training(image, label):
        # Training-time augmentation — applied after cache() below so the
        # random ops re-run every epoch
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, 0.1)
        image = tf.image.random_contrast(image, 0.9, 1.1)
        return image, label

    # Get data
    paths, labels = get_image_paths_and_labels(image_dir)

    # Create dataset
    dataset = tf.data.Dataset.from_tensor_slices((paths, labels))

    # Build pipeline: decode/resize -> cache -> augment -> shuffle -> batch -> prefetch
    dataset = (dataset
              .map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
              .cache()  # Cache preprocessing results
              .map(augment_for_training, num_parallel_calls=tf.data.AUTOTUNE)
              .shuffle(1000)
              .batch(batch_size)
              .prefetch(tf.data.AUTOTUNE)
              )

    return dataset

# Usage example
# train_dataset = create_image_classification_pipeline('train_images/')
# val_dataset = create_image_classification_pipeline('val_images/')

2. Text Classification Data Pipeline

python
def create_text_classification_pipeline(texts, labels, vocab_size=10000, max_length=100):
    """Create a text-classification input pipeline.

    Args:
        texts: Sequence of raw strings.
        labels: Integer label per text.
        vocab_size: Tokenizer vocabulary cap.
        max_length: Sequences are padded/truncated to this length.

    Returns:
        (dataset, tokenizer): a dataset of (int32 token ids, int32 label)
        pairs, plus the fitted tokenizer for reuse at inference time.
    """

    # Fit the tokenizer on the whole corpus up front
    tokenizer = tf.keras.preprocessing.text.Tokenizer(
        num_words=vocab_size,
        oov_token='<OOV>'
    )
    tokenizer.fit_on_texts(texts)

    def preprocess_text(text, label):
        # Runs eagerly via tf.py_function below, so .numpy() is available
        sequences = tokenizer.texts_to_sequences([text.numpy().decode('utf-8')])
        sequence = tf.constant(sequences[0], dtype=tf.int32)

        # Truncate first, THEN pad. The original padded with
        # `max_length - len(sequence)`, which is negative for texts
        # longer than max_length and makes tf.pad raise (paddings must
        # be non-negative).
        sequence = sequence[:max_length]
        pad_amount = max_length - tf.shape(sequence)[0]
        sequence = tf.pad(sequence, [[0, pad_amount]])

        return sequence, label

    # Create dataset
    dataset = tf.data.Dataset.from_tensor_slices((texts, labels))

    # Preprocessing: py_function bridges the Python tokenizer into the graph
    dataset = dataset.map(
        lambda text, label: tf.py_function(
            preprocess_text, [text, label], [tf.int32, tf.int32]
        ),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    return dataset, tokenizer

# Example usage: Chinese movie-review snippets with binary sentiment labels
# (1 = positive, 0 = negative, matching the label values below)
sample_texts = ["这是一个好电影", "糟糕的体验", "非常推荐"]
sample_labels = [1, 0, 1]

# text_dataset, text_tokenizer = create_text_classification_pipeline(sample_texts, sample_labels)

Summary

The tf.data API is a powerful tool in TensorFlow for data processing, with main features including:

  1. Multiple Data Sources: Supports memory, files, generators, and various data sources
  2. Rich Transformation Operations: map, filter, batch, shuffle, etc.
  3. Performance Optimization: Parallel processing, prefetching, caching, and other optimization techniques
  4. Data Augmentation: Built-in image and text augmentation capabilities
  5. Easy Debugging: Provides data inspection and validation tools

Mastering the tf.data API will greatly improve your data processing efficiency and model training performance!

Content is for learning and research only.