Transfer Learning: Using Data from Different Tasks

Powerful Technique

Transfer learning lets you use data from a totally different task to improve performance on your application, which is especially valuable when you don't have much labeled data.

Problem: Recognize handwritten digits 0-9 with only limited labeled training data

Solution Process:

  1. Start with Large Dataset

    • Find dataset with 1 million images
    • 1,000 different classes (cats, dogs, cars, people, etc.)
    • Train neural network on this large, diverse dataset
  2. Learn General Parameters

    • Network learns W¹, b¹ for first layer
    • W², b² for second layer
    • Continue through W⁴, b⁴
    • W⁵, b⁵ for output layer (1,000 units)
  3. Adapt for Target Task

    • Copy first four layers (W¹,b¹ through W⁴,b⁴)
    • Replace output layer with 10 units (digits 0-9)
    • Initialize new W⁵, b⁵ from scratch

Two training options (both sketched in code below):

Option 1: Train Only the Output Layer

  • Fix parameters: Keep W¹,b¹ through W⁴,b⁴ unchanged
  • Train only: The new output-layer parameters W⁵,b⁵
  • Best for: Very small target datasets

Option 2: Train All Parameters

  • Initialize: Use pre-trained values for W¹,b¹ through W⁴,b⁴
  • Fine-tune: Update all parameters, including the output layer
  • Best for: Larger target datasets
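
As a concrete illustration, here is a minimal PyTorch sketch of both options. The layer sizes, the 784-dimensional input, and the nn.Sequential structure are illustrative assumptions standing in for the five-layer network described above, not a prescribed architecture.

```python
import torch
import torch.nn as nn

# Hypothetical pre-trained 5-layer network (1,000-way output), standing in
# for the network trained on the large, diverse dataset.
pretrained = nn.Sequential(
    nn.Linear(784, 512), nn.ReLU(),  # layer 1: W¹, b¹
    nn.Linear(512, 256), nn.ReLU(),  # layer 2: W², b²
    nn.Linear(256, 128), nn.ReLU(),  # layer 3: W³, b³
    nn.Linear(128, 64),  nn.ReLU(),  # layer 4: W⁴, b⁴
    nn.Linear(64, 1000),             # layer 5: W⁵, b⁵ (1,000 classes)
)
# ... pre-trained weights would be loaded here ...

# Copy layers 1-4 and replace the output layer with 10 units (digits 0-9).
model = nn.Sequential(
    *list(pretrained.children())[:-1],  # W¹,b¹ through W⁴,b⁴ (copied)
    nn.Linear(64, 10),                  # new W⁵, b⁵, initialized from scratch
)

# Option 1: freeze the copied layers and train only the new output layer.
for layer in list(model.children())[:-1]:
    for p in layer.parameters():
        p.requires_grad = False

# Option 2: fine-tune everything; simply leave requires_grad = True,
# so the pre-trained values serve as the initialization.
```

With Option 1, only the 10-unit output layer receives gradient updates; with Option 2, the pre-trained values act as a good starting point for training the whole network.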

Small Dataset

Recommendation: Option 1

  • Only update output layer
  • Preserve learned features
  • Prevent overfitting

Larger Dataset

Recommendation: Option 2

  • Fine-tune all parameters
  • Adapt features to new task
  • Better performance potential

Transfer learning thus has two phases:

Supervised Pre-training

  • Train neural network on large dataset (1 million images)
  • Learn general features from diverse tasks
  • Often done by researchers and shared publicly

Fine-tuning

  • Adapt pre-trained parameters to specific task
  • Run gradient descent on target dataset (see the sketch below)
  • Adjust weights for handwritten digit recognition
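
A minimal sketch of that gradient-descent step, assuming the `model` from the earlier sketch and a hypothetical `digit_loader` DataLoader over your small labeled digit dataset:

```python
import torch

# Build the optimizer over only the parameters that require gradients:
# under Option 1 that is just the new output layer; under Option 2 it is
# every parameter in the network.
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)
loss_fn = torch.nn.CrossEntropyLoss()

for x_batch, y_batch in digit_loader:  # hypothetical DataLoader of digit images
    optimizer.zero_grad()
    loss = loss_fn(model(x_batch), y_batch)  # x_batch assumed shape (N, 784)
    loss.backward()
    optimizer.step()
```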

Early layers learn low-level features:

  • Layer 1: Edge detection (lines, curves)
  • Layer 2: Corner and basic shape detection
  • Layer 3: More complex shapes and curves

These features generalize across tasks:

  • Edge detection useful for many vision problems
  • Basic shapes relevant to digits, letters, objects
  • Transfer from cats/dogs/cars helps with digit recognition

Training on diverse objects (cats, dogs, cars, people) teaches networks to detect:

  • Edges and corners: Fundamental visual elements
  • Basic geometric shapes: Building blocks of complex objects
  • Generic visual features: Applicable across many computer vision tasks
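
You can see this directly by visualizing a pre-trained network's first-layer filters. A minimal sketch, assuming torchvision's ImageNet-pretrained ResNet-18 (an illustrative model choice) and matplotlib; many of the 64 filters resemble edge, line, and color-blob detectors:

```python
import matplotlib.pyplot as plt
from torchvision import models

# Load an ImageNet-pretrained model (illustrative choice) and grab the
# weights of its first convolutional layer: 64 filters of shape 3x7x7.
net = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
filters = net.conv1.weight.detach()

# Plot each filter on an 8x8 grid.
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for ax, f in zip(axes.flat, filters):
    f = (f - f.min()) / (f.max() - f.min())      # rescale to [0, 1] for display
    ax.imshow(f.permute(1, 2, 0).numpy())        # channels-last for imshow
    ax.axis("off")
plt.show()
```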

Critical requirement: The pre-training and fine-tuning data must be the same input type

  • Pre-training: Images of a given size
  • Fine-tuning: Images of that same size
  • ✅ Compatible for transfer

A network pre-trained on images helps with other image tasks; it would not transfer to a different input type such as audio.
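
In practice this means resizing your fine-tuning images to the dimensions the network was pre-trained on. A minimal torchvision sketch, assuming a model that expects 224×224 RGB inputs (the normalization values are the standard ImageNet statistics):

```python
from torchvision import transforms

# Make grayscale digit images match an assumed 224x224 RGB pre-training input.
preprocess = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # replicate gray into 3 channels
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet statistics
                         std=[0.229, 0.224, 0.225]),
])
```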

Step 1: Obtain Pre-trained Model

  • Download neural network pre-trained on large dataset
  • Ensure same input type as your application
  • Alternative: Train your own (but downloading one is usually better)

Step 2: Fine-tune on Your Data

  • Replace output layer for your specific task
  • Apply Option 1 or Option 2 training approach
  • Use a much smaller dataset (sometimes as few as 50 images; see the sketch below)
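
Putting the two steps together, a minimal sketch using torchvision (ResNet-18 and the ImageNet weights are illustrative choices, not the only option):

```python
import torch.nn as nn
from torchvision import models

# Step 1: download a network pre-trained on a large image dataset.
net = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Step 2: replace the 1,000-way output layer with a 10-unit layer for digits,
# then train with Option 1 (freeze the rest) or Option 2 (fine-tune all).
net.fc = nn.Linear(net.fc.in_features, 10)

# Option 1: freeze everything except the new output layer.
for name, p in net.named_parameters():
    if not name.startswith("fc."):
        p.requires_grad = False
```
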
Remarkable Results

“I’ve sometimes trained neural networks on as few as 50 images that worked quite well using this technique, when pre-trained on much larger datasets.”

Well-known examples of the pre-train-then-fine-tune recipe:

  • GPT-3: Pre-trained on large text datasets, fine-tuned for specific applications
  • BERT: Pre-trained language model adapted for various NLP tasks
  • ImageNet models: Vision models pre-trained on the ImageNet dataset, fine-tuned for many computer vision applications

Transfer learning exemplifies machine learning community collaboration:

  • Researchers share pre-trained models freely
  • Build on each other’s work collectively
  • Achieve better results than individual efforts
  • Open sharing of ideas, code, and parameters

This collaborative approach enables anyone to leverage powerful pre-trained models and fine-tune them for specific applications, dramatically reducing the data requirements for achieving good performance.