Skip to content

Image Classification Project

This chapter will demonstrate how to build, train, and deploy a practical deep learning model for image classification using TensorFlow through a complete project. We will start from data preparation and gradually complete the entire machine learning workflow.

Project Overview

We will build an image classifier capable of recognizing ten everyday object categories (animals and vehicles), using the CIFAR-10 dataset as an example, and then extend to custom datasets.

Project Goals

  • Build high-accuracy image classification models
  • Learn data preprocessing and augmentation techniques
  • Master model training and tuning methods
  • Implement model evaluation and visualization
  • Deploy models for practical applications
python
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import os
import cv2
from pathlib import Path

# Set random seeds
# Fix both TF and NumPy RNGs so training runs are reproducible.
tf.random.set_seed(42)
np.random.seed(42)

# Report the runtime environment (empty GPU list means CPU-only).
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {tf.config.list_physical_devices('GPU')}")

Data Preparation

Loading CIFAR-10 Dataset

python
def load_cifar10_data():
    """Load the CIFAR-10 dataset and report its basic statistics.

    Returns:
        ((x_train, y_train), (x_test, y_test), class_names), where the
        image arrays have shape (N, 32, 32, 3) and class_names maps the
        integer labels 0-9 to human-readable names.
    """
    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

    # Fixed CIFAR-10 label order: index == integer label.
    class_names = [
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    ]

    # Print basic shape/class information for a quick sanity check.
    print(f"Training set shape: {x_train.shape}")
    print(f"Test set shape: {x_test.shape}")
    print(f"Number of classes: {len(class_names)}")

    return (x_train, y_train), (x_test, y_test), class_names

def preprocess_data(x_train, y_train, x_test, y_test):
    """Scale images into [0, 1] and one-hot encode the integer labels.

    Returns:
        (x_train, y_train, x_test, y_test, num_classes), where the image
        arrays are float32 and num_classes is inferred from y_train.
    """
    # Convert uint8 pixels to float32 in the unit interval.
    x_train = x_train.astype('float32') / 255.0
    x_test = x_test.astype('float32') / 255.0

    # Infer the class count from the training labels, then one-hot encode.
    num_classes = len(np.unique(y_train))
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)

    return x_train, y_train, x_test, y_test, num_classes

# Load and preprocess data
# Fetch CIFAR-10, then normalize the images and one-hot encode the labels.
(x_train, y_train), (x_test, y_test), class_names = load_cifar10_data()
x_train, y_train, x_test, y_test, num_classes = preprocess_data(
    x_train, y_train, x_test, y_test
)

Data Visualization

python
def visualize_dataset(x_train, y_train, class_names, num_samples=25):
    """Show the first `num_samples` training images in a 5x5 grid,
    each titled with its decoded class name (labels are one-hot)."""
    plt.figure(figsize=(12, 12))

    for index in range(num_samples):
        axis = plt.subplot(5, 5, index + 1)
        axis.imshow(x_train[index])
        # argmax recovers the integer label from the one-hot vector.
        axis.set_title(f'{class_names[np.argmax(y_train[index])]}')
        axis.axis('off')

    plt.tight_layout()
    plt.show()

def plot_class_distribution(y_train, class_names):
    """Bar-plot the number of training samples per class.

    Expects one-hot labels, so a column-wise sum yields per-class counts.
    """
    class_counts = y_train.sum(axis=0)

    plt.figure(figsize=(12, 6))
    bars = plt.bar(class_names, class_counts)
    plt.title('Training Set Class Distribution')
    plt.xlabel('Class')
    plt.ylabel('Number of Samples')
    plt.xticks(rotation=45)

    # Annotate each bar with its exact sample count.
    for bar, count in zip(bars, class_counts):
        x_center = bar.get_x() + bar.get_width() / 2
        plt.text(x_center, bar.get_height() + 10, f'{int(count)}',
                 ha='center', va='bottom')

    plt.tight_layout()
    plt.show()

# Visualize data: a sample image grid and the per-class sample counts.
visualize_dataset(x_train, y_train, class_names)
plot_class_distribution(y_train, class_names)

Data Augmentation

python
def create_data_augmentation():
    """Build a stack of random image-augmentation layers.

    The random transforms are only applied when the returned model is
    called with training=True; at inference they pass inputs through.
    """
    augmentation_layers = [
        keras.layers.RandomFlip("horizontal"),
        keras.layers.RandomRotation(0.1),
        keras.layers.RandomZoom(0.1),
        keras.layers.RandomContrast(0.1),
        keras.layers.RandomBrightness(0.1),
    ]
    return keras.Sequential(augmentation_layers)

def visualize_augmentation(x_train, data_augmentation):
    """Plot one training image next to five random augmentations of it."""
    # Slice (not index) to keep the batch dimension the pipeline expects.
    sample_image = x_train[:1]

    plt.figure(figsize=(15, 5))

    # Leftmost panel: the untouched original.
    plt.subplot(1, 6, 1)
    plt.imshow(sample_image[0])
    plt.title('Original Image')
    plt.axis('off')

    # Panels 2-6: independent random augmentations (training=True
    # activates the random transforms).
    for column in range(2, 7):
        augmented = data_augmentation(sample_image, training=True)
        plt.subplot(1, 6, column)
        plt.imshow(augmented[0])
        plt.title(f'Augmented {column - 1}')
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Create data augmentation pipeline and preview its effect on one sample.
data_augmentation = create_data_augmentation()
visualize_augmentation(x_train, data_augmentation)

Model Building

Basic CNN Model

python
def create_basic_cnn(input_shape, num_classes, augmentation=None):
    """Create a basic three-block CNN image classifier.

    Args:
        input_shape: Shape of a single input image, e.g. (32, 32, 3).
        num_classes: Number of output classes (softmax units).
        augmentation: Optional augmentation layer/model applied to the
            inputs. Defaults to the module-level `data_augmentation`
            pipeline, matching the original behavior.

    Returns:
        An uncompiled keras.Sequential model.
    """
    if augmentation is None:
        # Fall back to the module-level pipeline the original relied on.
        augmentation = data_augmentation

    model = keras.Sequential([
        # Fix: declare the input shape explicitly. Previously
        # `input_shape` was passed to a Conv2D that was NOT the first
        # layer (the augmentation layer was), where it has no effect.
        keras.layers.Input(shape=input_shape),

        # Data augmentation layer (active only during training)
        augmentation,

        # First convolution block
        keras.layers.Conv2D(32, (3, 3), activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.Conv2D(32, (3, 3), activation='relu'),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Dropout(0.25),

        # Second convolution block
        keras.layers.Conv2D(64, (3, 3), activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.Conv2D(64, (3, 3), activation='relu'),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Dropout(0.25),

        # Third convolution block
        keras.layers.Conv2D(128, (3, 3), activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.Conv2D(128, (3, 3), activation='relu'),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Dropout(0.25),

        # Fully connected classification head
        keras.layers.Flatten(),
        keras.layers.Dense(512, activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(num_classes, activation='softmax')
    ])

    return model

# Create basic model (CIFAR-10 images are 32x32 RGB) and print its layers.
basic_model = create_basic_cnn((32, 32, 3), num_classes)
basic_model.summary()

ResNet Style Model

python
def residual_block(x, filters, kernel_size=3, stride=1, conv_shortcut=False):
    """Two-stage conv-BN residual block with a skip connection.

    When conv_shortcut is True the skip path is a 1x1 conv + BN so its
    shape matches the main path after striding / channel changes;
    otherwise the input is added back unchanged.
    """
    shortcut = x
    if conv_shortcut:
        shortcut = keras.layers.Conv2D(filters, 1, strides=stride)(shortcut)
        shortcut = keras.layers.BatchNormalization()(shortcut)

    # Main path: conv-BN-ReLU, then conv-BN (activation comes after the add).
    out = keras.layers.Conv2D(filters, kernel_size, strides=stride,
                              padding='same')(x)
    out = keras.layers.BatchNormalization()(out)
    out = keras.layers.ReLU()(out)

    out = keras.layers.Conv2D(filters, kernel_size, padding='same')(out)
    out = keras.layers.BatchNormalization()(out)

    out = keras.layers.Add()([shortcut, out])
    return keras.layers.ReLU()(out)

def create_resnet_model(input_shape, num_classes):
    """Build a small ResNet-style classifier: a conv stem followed by
    three stages of two residual blocks each, then GAP + softmax."""
    inputs = keras.layers.Input(shape=input_shape)

    # Augment, then the stem: 7x7 strided conv + max-pool.
    net = data_augmentation(inputs)
    net = keras.layers.Conv2D(64, 7, strides=2, padding='same')(net)
    net = keras.layers.BatchNormalization()(net)
    net = keras.layers.ReLU()(net)
    net = keras.layers.MaxPooling2D(3, strides=2, padding='same')(net)

    # Three residual stages. The first block of each stage uses a conv
    # shortcut to change channels (and downsample after stage one).
    for stage_filters, stage_stride in ((64, 1), (128, 2), (256, 2)):
        net = residual_block(net, stage_filters, stride=stage_stride,
                             conv_shortcut=True)
        net = residual_block(net, stage_filters)

    # Global average pooling + dropout before the classifier.
    net = keras.layers.GlobalAveragePooling2D()(net)
    net = keras.layers.Dropout(0.5)(net)
    outputs = keras.layers.Dense(num_classes, activation='softmax')(net)

    return keras.Model(inputs, outputs)

# Create ResNet model for 32x32 RGB inputs and print its layers.
resnet_model = create_resnet_model((32, 32, 3), num_classes)
resnet_model.summary()

Using Pre-trained Models

python
def create_transfer_learning_model(input_shape, num_classes, base_model_name='EfficientNetB0'):
    """Create a transfer-learning classifier on a frozen ImageNet backbone.

    Args:
        input_shape: Shape of the raw input images, e.g. (32, 32, 3).
        num_classes: Number of output classes.
        base_model_name: 'EfficientNetB0' or 'ResNet50'.

    Returns:
        (model, base_model): the full model and the frozen backbone
        (returned so callers can later unfreeze it for fine-tuning).

    Raises:
        ValueError: If base_model_name is not supported.
    """
    # Images are upsampled to the backbone's native resolution before the
    # backbone runs, so the backbone must be built for that size — not
    # for the raw `input_shape` (the original built it at 32x32 and then
    # fed it 224x224 tensors).
    backbone_input_shape = (224, 224, 3)

    # Select the backbone together with its *matching* preprocessing
    # function. The original always applied EfficientNet preprocessing,
    # which feeds ResNet50 wrongly scaled inputs when that backbone is
    # selected.
    if base_model_name == 'EfficientNetB0':
        base_model = keras.applications.EfficientNetB0(
            weights='imagenet',
            include_top=False,
            input_shape=backbone_input_shape
        )
        preprocess_input = keras.applications.efficientnet.preprocess_input
    elif base_model_name == 'ResNet50':
        base_model = keras.applications.ResNet50(
            weights='imagenet',
            include_top=False,
            input_shape=backbone_input_shape
        )
        preprocess_input = keras.applications.resnet50.preprocess_input
    else:
        raise ValueError(f"Unsupported model: {base_model_name}")

    # Freeze the backbone so only the new head trains.
    base_model.trainable = False

    inputs = keras.layers.Input(shape=input_shape)

    # Data augmentation (active only during training).
    x = data_augmentation(inputs)

    # Upsample to the backbone's expected resolution, then apply the
    # backbone-specific preprocessing.
    # NOTE(review): preprocess_input generally assumes 0-255 pixel
    # values, but upstream code scales images to [0, 1] — confirm the
    # end-to-end pipeline before training.
    x = keras.layers.Resizing(224, 224)(x)
    x = preprocess_input(x)

    # training=False keeps BatchNorm layers in inference mode even if
    # the backbone is later unfrozen for fine-tuning.
    x = base_model(x, training=False)

    # Lightweight classification head.
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dropout(0.2)(x)
    outputs = keras.layers.Dense(num_classes, activation='softmax')(x)

    model = keras.Model(inputs, outputs)
    return model, base_model

# Create transfer learning model on a frozen ImageNet backbone.
transfer_model, base_model = create_transfer_learning_model((32, 32, 3), num_classes)
transfer_model.summary()

Model Training

Training Configuration

python
def compile_model(model, learning_rate=0.001):
    """Compile `model` with Adam, categorical cross-entropy, and
    accuracy / top-k-accuracy metrics. Returns the model for chaining."""
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss='categorical_crossentropy',
        metrics=['accuracy', 'top_k_categorical_accuracy'],
    )
    return model

def create_callbacks(model_name):
    """Build the standard training callbacks for `model_name`:
    best-model checkpointing, early stopping, LR reduction on plateau,
    and TensorBoard logging under logs/<model_name>."""
    # Keep the single best model (by validation accuracy) on disk.
    checkpoint = keras.callbacks.ModelCheckpoint(
        f'best_{model_name}.h5',
        monitor='val_accuracy',
        save_best_only=True,
        save_weights_only=False,
        verbose=1,
    )

    # Stop when validation accuracy plateaus; roll back to best weights.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor='val_accuracy',
        patience=10,
        restore_best_weights=True,
        verbose=1,
    )

    # Shrink the learning rate when validation loss stops improving.
    lr_scheduler = keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=5,
        min_lr=1e-7,
        verbose=1,
    )

    # Per-run TensorBoard logs with histograms, graph, and images.
    tensorboard = keras.callbacks.TensorBoard(
        log_dir=f'logs/{model_name}',
        histogram_freq=1,
        write_graph=True,
        write_images=True,
    )

    return [checkpoint, early_stopping, lr_scheduler, tensorboard]

# Compile model and create its checkpoint/early-stop/LR/TensorBoard callbacks.
basic_model = compile_model(basic_model)
callbacks = create_callbacks('basic_cnn')

Training Process

python
def train_model(model, x_train, y_train, x_test, y_test,
                callbacks, epochs=100, batch_size=32, validation_split=0.2):
    """
    Train model on (x_train, y_train) with the given callbacks.

    Validation data is carved out of the training set via
    `validation_split`.

    NOTE(review): x_test and y_test are accepted but never used here —
    confirm whether the test set was meant to be passed to fit() as
    validation_data instead of using validation_split.

    Returns:
        The keras History object produced by model.fit().
    """
    history = model.fit(
        x_train, y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=validation_split,
        callbacks=callbacks,
        verbose=1
    )

    return history

def plot_training_history(history):
    """Plot loss, accuracy, top-k accuracy, and (if logged) the learning
    rate from a keras History object in a 2x2 grid."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # (axis, history key, panel title, y-label, legend label)
    panels = [
        (axes[0, 0], 'loss', 'Model Loss', 'Loss', 'Loss'),
        (axes[0, 1], 'accuracy', 'Model Accuracy', 'Accuracy', 'Accuracy'),
        (axes[1, 0], 'top_k_categorical_accuracy', 'Top-K Accuracy',
         'Accuracy', 'Top-K Accuracy'),
    ]
    for ax, key, title, ylabel, label in panels:
        ax.plot(history.history[key], label=f'Training {label}')
        ax.plot(history.history[f'val_{key}'], label=f'Validation {label}')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()

    # The learning rate series only exists when a scheduler recorded it.
    if 'lr' in history.history:
        axes[1, 1].plot(history.history['lr'])
        axes[1, 1].set_title('Learning Rate')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Learning Rate')
        axes[1, 1].set_yscale('log')

    plt.tight_layout()
    plt.show()

# Train model and plot its loss/accuracy curves.
print("Starting to train basic CNN model...")
history = train_model(basic_model, x_train, y_train, x_test, y_test, callbacks)
plot_training_history(history)

Model Evaluation

Performance Evaluation

python
def evaluate_model(model, x_test, y_test, class_names):
    """Evaluate `model` on the test set and print a per-class report.

    Returns:
        (probabilities, predicted_labels, true_labels): raw predicted
        distributions plus the argmax class indices for predictions and
        ground truth (labels arrive one-hot encoded).
    """
    probabilities = model.predict(x_test)
    predicted_labels = np.argmax(probabilities, axis=1)
    true_labels = np.argmax(y_test, axis=1)

    # Overall metrics from model.evaluate (loss, accuracy, top-k accuracy).
    test_loss, test_accuracy, test_top_k = model.evaluate(x_test, y_test, verbose=0)
    for metric_name, value in (("Test Loss", test_loss),
                               ("Test Accuracy", test_accuracy),
                               ("Top-K Accuracy", test_top_k)):
        print(f"{metric_name}: {value:.4f}")

    # Per-class precision/recall/F1 from scikit-learn.
    print("\nClassification Report:")
    print(classification_report(true_labels, predicted_labels,
                                target_names=class_names))

    return probabilities, predicted_labels, true_labels

def plot_confusion_matrix(y_true, y_pred, class_names):
    """Draw a confusion-matrix heatmap plus a per-class accuracy chart."""
    cm = confusion_matrix(y_true, y_pred)

    # Heatmap of raw counts.
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Class')
    plt.ylabel('True Class')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

    # Per-class accuracy = diagonal (correct) over row sums (true totals).
    per_class_accuracy = cm.diagonal() / cm.sum(axis=1)

    plt.figure(figsize=(12, 6))
    bars = plt.bar(class_names, per_class_accuracy)
    plt.title('Accuracy by Class')
    plt.xlabel('Class')
    plt.ylabel('Accuracy')
    plt.xticks(rotation=45)

    # Annotate each bar with its accuracy value.
    for bar, accuracy in zip(bars, per_class_accuracy):
        midpoint = bar.get_x() + bar.get_width() / 2
        plt.text(midpoint, bar.get_height() + 0.01, f'{accuracy:.3f}',
                 ha='center', va='bottom')

    plt.tight_layout()
    plt.show()

# Evaluate model on the held-out test set and show the confusion matrix.
y_pred, y_pred_classes, y_true_classes = evaluate_model(
    basic_model, x_test, y_test, class_names
)
plot_confusion_matrix(y_true_classes, y_pred_classes, class_names)

Error Analysis

python
def analyze_errors(x_test, y_true, y_pred, y_pred_classes, class_names, num_examples=20):
    """Visualize a random sample of misclassified test images.

    Args:
        x_test: Test images.
        y_true: True class indices.
        y_pred: Predicted per-class probability distributions.
        y_pred_classes: Predicted class indices.
        class_names: Index -> class-name mapping.
        num_examples: Maximum number of errors to display.
    """
    # Find incorrectly predicted samples
    incorrect_indices = np.where(y_true != y_pred_classes)[0]

    # Fix: a perfect model leaves nothing to plot — the original would
    # create a zero-height figure and an empty subplot grid.
    if len(incorrect_indices) == 0:
        print("No incorrect predictions to analyze.")
        return

    # Randomly select some error samples (all of them if there are few).
    if len(incorrect_indices) > num_examples:
        selected_indices = np.random.choice(incorrect_indices, num_examples, replace=False)
    else:
        selected_indices = incorrect_indices

    # Lay the selection out on a 5-wide grid.
    cols = 5
    rows = (len(selected_indices) + cols - 1) // cols

    plt.figure(figsize=(15, 3 * rows))

    for i, idx in enumerate(selected_indices):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(x_test[idx])

        true_label = class_names[y_true[idx]]
        pred_label = class_names[y_pred_classes[idx]]
        confidence = np.max(y_pred[idx])

        plt.title(f'True: {true_label}\nPred: {pred_label}\nConfidence: {confidence:.3f}')
        plt.axis('off')

    plt.tight_layout()
    plt.show()

def plot_prediction_confidence(y_pred, y_true, y_pred_classes):
    """Plot confidence histograms (hits vs misses) and a simple
    confidence-vs-accuracy calibration chart."""
    # Confidence = the probability assigned to the predicted class.
    confidences = y_pred.max(axis=1)
    is_correct = (y_true == y_pred_classes)

    plt.figure(figsize=(12, 5))

    # Left panel: confidence histograms, split by correctness.
    plt.subplot(1, 2, 1)
    plt.hist(confidences[is_correct], bins=50, alpha=0.7,
             label='Correct Prediction', color='green')
    plt.hist(confidences[~is_correct], bins=50, alpha=0.7,
             label='Incorrect Prediction', color='red')
    plt.xlabel('Prediction Confidence')
    plt.ylabel('Frequency')
    plt.title('Prediction Confidence Distribution')
    plt.legend()

    # Right panel: bucket predictions into ten confidence bins and plot
    # the empirical accuracy in each bin.
    plt.subplot(1, 2, 2)
    bin_edges = np.linspace(0, 1, 11)
    bin_accuracies = []
    for low, high in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (confidences >= low) & (confidences < high)
        bin_accuracies.append(np.mean(is_correct[in_bin]) if in_bin.any() else 0)

    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    plt.bar(bin_centers, bin_accuracies, width=0.08, alpha=0.7)
    plt.xlabel('Confidence Interval')
    plt.ylabel('Accuracy')
    plt.title('Confidence vs Accuracy')
    plt.ylim(0, 1)

    plt.tight_layout()
    plt.show()

# Error analysis: inspect misclassified samples and confidence calibration.
analyze_errors(x_test, y_true_classes, y_pred, y_pred_classes, class_names)
plot_prediction_confidence(y_pred, y_true_classes, y_pred_classes)

Model Optimization

Hyperparameter Tuning

python
def hyperparameter_tuning():
    """
    Hyperparameter tuning example.

    Uses keras-tuner RandomSearch to explore CNN depth, filter counts,
    kernel sizes, dropout rates, dense-layer sizes, and learning rate.

    NOTE(review): depends on module-level globals (`data_augmentation`,
    `num_classes`, `x_train`, `y_train`) and on the third-party
    keras-tuner package being installed.

    Returns:
        Tuple of (best_model, best_hyperparameters) after the search.
    """
    import keras_tuner as kt

    def build_model(hp):
        # Builder invoked once per trial; `hp` supplies sampled values.
        model = keras.Sequential()

        # Data augmentation
        model.add(data_augmentation)

        # Convolution layer count and per-block parameter tuning
        for i in range(hp.Int('num_conv_blocks', 2, 4)):
            model.add(keras.layers.Conv2D(
                filters=hp.Int(f'conv_{i}_filters', 32, 256, step=32),
                kernel_size=hp.Choice(f'conv_{i}_kernel', [3, 5]),
                activation='relu',
                padding='same'
            ))
            model.add(keras.layers.BatchNormalization())

            # Whether to add dropout after this block is itself tunable.
            if hp.Boolean(f'conv_{i}_dropout'):
                model.add(keras.layers.Dropout(hp.Float(f'conv_{i}_dropout_rate', 0.1, 0.5)))

            model.add(keras.layers.MaxPooling2D(2))

        # Fully connected layers
        model.add(keras.layers.Flatten())

        for i in range(hp.Int('num_dense_layers', 1, 3)):
            model.add(keras.layers.Dense(
                units=hp.Int(f'dense_{i}_units', 128, 1024, step=128),
                activation='relu'
            ))
            model.add(keras.layers.Dropout(hp.Float(f'dense_{i}_dropout', 0.2, 0.7)))

        model.add(keras.layers.Dense(num_classes, activation='softmax'))

        # Compile model (learning rate sampled log-uniformly)
        model.compile(
            optimizer=keras.optimizers.Adam(hp.Float('learning_rate', 1e-4, 1e-2, sampling='log')),
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )

        return model

    # Create tuner (results persisted under directory/project_name)
    tuner = kt.RandomSearch(
        build_model,
        objective='val_accuracy',
        max_trials=20,
        directory='hyperparameter_tuning',
        project_name='cifar10_classification'
    )

    # Search for best hyperparameters
    tuner.search(x_train, y_train,
                epochs=10,
                validation_split=0.2,
                verbose=1)

    # Get best model and its hyperparameter set
    best_model = tuner.get_best_models(num_models=1)[0]
    best_hyperparameters = tuner.get_best_hyperparameters(num_trials=1)[0]

    return best_model, best_hyperparameters

# Note: Actual running requires installing keras-tuner
# pip install keras-tuner

Model Ensemble

python
def create_ensemble_model(models, x_test, y_test):
    """Combine trained models by probability averaging and majority vote.

    Prints the accuracy of both ensembling strategies on (x_test, y_test)
    (y_test one-hot encoded).

    Returns:
        (ensemble_pred, voting_pred): averaged class probabilities and
        per-sample majority-vote class indices.
    """
    predictions = [model.predict(x_test) for model in models]

    # Soft voting: average the probability distributions.
    ensemble_pred = np.mean(predictions, axis=0)
    ensemble_classes = np.argmax(ensemble_pred, axis=1)

    # Hard voting: majority class among individual argmax predictions
    # (ties resolve to the lowest class index via bincount).
    individual_classes = [np.argmax(pred, axis=1) for pred in predictions]
    voting_pred = np.array([np.bincount(votes).argmax()
                           for votes in zip(*individual_classes)])

    # Score both strategies against the true labels.
    y_true = np.argmax(y_test, axis=1)
    ensemble_accuracy = np.mean(ensemble_classes == y_true)
    voting_accuracy = np.mean(voting_pred == y_true)

    print(f"Average ensemble accuracy: {ensemble_accuracy:.4f}")
    print(f"Voting ensemble accuracy: {voting_accuracy:.4f}")

    return ensemble_pred, voting_pred

# If you have multiple trained models, you can ensemble them
# ensemble_pred, voting_pred = create_ensemble_model([model1, model2, model3], x_test, y_test)

Custom Dataset

Data Loading and Preprocessing

python
def load_custom_dataset(data_dir, img_size=(224, 224), batch_size=32):
    """Build augmented train/validation generators from a directory tree.

    Expects one sub-directory per class under `data_dir`; 20% of the
    images are held out for validation via `validation_split`.
    """
    # Single generator with augmentation + rescaling; the split is shared.
    datagen = keras.preprocessing.image.ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True,
        validation_split=0.2
    )

    # Both flows read the same directory; `subset` picks the split.
    shared_kwargs = dict(
        target_size=img_size,
        batch_size=batch_size,
        class_mode='categorical',
    )
    train_generator = datagen.flow_from_directory(
        data_dir, subset='training', **shared_kwargs)
    validation_generator = datagen.flow_from_directory(
        data_dir, subset='validation', **shared_kwargs)

    return train_generator, validation_generator

def create_tf_dataset(data_dir, img_size=(224, 224), batch_size=32):
    """Build normalized, performance-tuned tf.data train/val datasets
    from a class-per-subdirectory image tree."""
    # The same seed + split fraction on both calls yields consistent,
    # disjoint training/validation subsets.
    shared_kwargs = dict(
        validation_split=0.2,
        seed=123,
        image_size=img_size,
        batch_size=batch_size,
    )
    train_ds = keras.utils.image_dataset_from_directory(
        data_dir, subset="training", **shared_kwargs)
    val_ds = keras.utils.image_dataset_from_directory(
        data_dir, subset="validation", **shared_kwargs)

    # Scale pixel values to [0, 1].
    rescale = keras.layers.Rescaling(1./255)
    train_ds = train_ds.map(lambda images, labels: (rescale(images), labels))
    val_ds = val_ds.map(lambda images, labels: (rescale(images), labels))

    # Cache decoded batches, shuffle the training stream, and overlap
    # input preprocessing with training via prefetch.
    AUTOTUNE = tf.data.AUTOTUNE
    train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
    val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

    return train_ds, val_ds

# Usage example
# train_ds, val_ds = create_tf_dataset('path/to/your/dataset')

Model Deployment

Model Saving and Loading

python
def save_model(model, model_path):
    """Persist `model` in three forms: full HDF5, TF SavedModel
    directory, and weights-only HDF5."""
    model.save(f'{model_path}.h5')                    # full model, HDF5
    model.save(f'{model_path}_savedmodel')            # SavedModel directory
    model.save_weights(f'{model_path}_weights.h5')    # weights only

    print(f"Model saved to: {model_path}")

def load_model(model_path):
    """Load the full HDF5 model previously written by save_model."""
    return keras.models.load_model(f'{model_path}.h5')

def convert_to_tflite(model, model_path):
    """Convert `model` to TensorFlow Lite with default optimizations
    and write the result to '<model_path>.tflite'."""
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # Default optimization set (enables e.g. weight quantization).
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    tflite_model = converter.convert()

    output_file = f'{model_path}.tflite'
    with open(output_file, 'wb') as f:
        f.write(tflite_model)

    print(f"TFLite model saved to: {model_path}.tflite")

# Save model in multiple formats and export a TFLite version.
save_model(basic_model, 'cifar10_classifier')
convert_to_tflite(basic_model, 'cifar10_classifier')

Inference Function

python
def create_prediction_function(model, class_names):
    """Return a closure that classifies a single image file.

    The closure loads the file, rescales it to the model's 32x32 input
    in [0, 1], and returns the top class, its confidence, and the
    top-3 (class, probability) pairs.
    """
    def predict_image(image_path):
        # Load, resize, and normalize into a one-image batch.
        img = keras.utils.load_img(image_path, target_size=(32, 32))
        batch = tf.expand_dims(keras.utils.img_to_array(img), 0) / 255.0

        # Single probability distribution over the classes.
        probs = model.predict(batch)[0]
        best_index = int(np.argmax(probs))

        # Indices of the three highest probabilities, descending.
        top_3_indices = np.argsort(probs)[-3:][::-1]

        return {
            'predicted_class': class_names[best_index],
            'confidence': float(probs[best_index]),
            'top_3_predictions': [(class_names[i], float(probs[i]))
                                  for i in top_3_indices],
        }

    return predict_image

def batch_predict(model, image_paths, class_names):
    """Classify many image files; failures are reported and skipped.

    Returns:
        A list of prediction dicts (as produced by the single-image
        predictor), each annotated with its 'image_path'.
    """
    predict_fn = create_prediction_function(model, class_names)
    results = []

    for image_path in image_paths:
        try:
            prediction = predict_fn(image_path)
        except Exception as e:
            # Best-effort batch: report the failure and keep going.
            print(f"Error processing image {image_path}: {e}")
            continue
        prediction['image_path'] = image_path
        results.append(prediction)

    return results

# Create prediction function (a closure around the trained model).
predict_fn = create_prediction_function(basic_model, class_names)

# Example usage
# result = predict_fn('path/to/test/image.jpg')
# print(result)

Web Application Deployment

python
def create_flask_app(model, class_names):
    """Build a Flask app exposing '/' (upload page) and '/predict'.

    The /predict endpoint accepts a multipart 'file' upload, resizes it
    to the model's 32x32 input, and returns the predicted class, its
    confidence, and the full per-class probability distribution as JSON.
    """
    from flask import Flask, request, jsonify, render_template
    from PIL import Image

    app = Flask(__name__)

    @app.route('/')
    def index():
        return render_template('index.html')

    @app.route('/predict', methods=['POST'])
    def predict():
        try:
            # Validate the upload before touching the model.
            if 'file' not in request.files:
                return jsonify({'error': 'No file uploaded'})

            upload = request.files['file']
            if upload.filename == '':
                return jsonify({'error': 'No file selected'})

            # Match the model's training input: 32x32, scaled to [0, 1].
            image = Image.open(upload.stream).resize((32, 32))
            batch = np.expand_dims(np.array(image) / 255.0, 0)

            probs = model.predict(batch)[0]
            best_index = int(np.argmax(probs))

            return jsonify({
                'predicted_class': class_names[best_index],
                'confidence': float(probs[best_index]),
                'all_predictions': {class_names[i]: float(probs[i])
                                    for i in range(len(class_names))},
            })

        except Exception as e:
            # Surface failures to the client as JSON, not an HTML 500 page.
            return jsonify({'error': str(e)})

    return app

# Create Flask application
# app = create_flask_app(basic_model, class_names)
# app.run(debug=True)

Summary

This chapter demonstrated the complete workflow of a deep learning project through an end-to-end image classification example:

Key Points:

  1. Data Preparation: Data loading, preprocessing, and visualization
  2. Data Augmentation: Improve model generalization capability
  3. Model Design: From basic CNN to transfer learning
  4. Training Optimization: Callback functions, hyperparameter tuning
  5. Model Evaluation: Multi-dimensional performance analysis
  6. Deployment Application: Model saving, inference, and web applications

Best Practices:

  • Fully understand data characteristics
  • Design reasonable data augmentation strategies
  • Select appropriate model architecture
  • Use appropriate evaluation metrics
  • Conduct error analysis and model optimization
  • Consider actual deployment requirements

Next chapter we will learn text classification projects and explore natural language processing applications.

Content is for learning and research only.