Convolutional Neural Network Bird Classifier
Stack
Here are the technologies used in this project:
- PyTorch
Overview
- Developed a custom bird classification model using PyTorch with an accuracy of 87% across 20 bird species from a dataset of 3,113 images.
- Focused on optimizing generalization and reducing overfitting, achieving improved performance on unseen test data.
Model Architecture
- Architecture Selection:
  - Experimented with a Variational Autoencoder (VAE) and a Generative Adversarial Network (GAN).
  - Chose a Convolutional Neural Network (CNN) due to its:
    - Superior feature detection capabilities.
    - Efficiency in training.
    - Proven effectiveness in handling image data.
Sample Code of the Model
import torch
import torch.nn as nn
import torch.nn.functional as F

class BirdModel(nn.Module):
    def __init__(self, num_classes=20):
        super().__init__()
        # input: 3 x 224 x 224 RGB image
        self.conv1 = nn.Conv2d(3, 10, 5)    # -> 10 x 220 x 220
        self.pool1 = nn.MaxPool2d(2, 2)     # -> 10 x 110 x 110
        self.conv2 = nn.Conv2d(10, 20, 5)   # -> 20 x 106 x 106
        self.pool2 = nn.MaxPool2d(2, 2)     # -> 20 x 53 x 53
        self.fc1 = nn.Linear(20 * 53 * 53, 50)
        self.fc2 = nn.Linear(50, num_classes)  # num_classes rather than a hard-coded 20

    def forward(self, x):
        # Device placement (model.to(device)) belongs in setup code, not in forward.
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.reshape(x.size(0), -1)   # flatten to (batch, 20 * 53 * 53)
        x = F.relu(self.fc1(x))        # nonlinearity so fc1/fc2 don't collapse into one linear map
        return self.fc2(x)             # raw logits; CrossEntropyLoss applies softmax
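A quick smoke test confirms the layer arithmetic: a dummy batch of two 224 x 224 RGB images should come out as one logit per class. The class is restated here so the snippet runs standalone.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BirdModel(nn.Module):
    """Same architecture as above, repeated so this snippet is self-contained."""
    def __init__(self, num_classes=20):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(20 * 53 * 53, 50)
        self.fc2 = nn.Linear(50, num_classes)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

model = BirdModel(num_classes=20)
logits = model(torch.randn(2, 3, 224, 224))  # dummy batch of two RGB images
print(logits.shape)  # torch.Size([2, 20])
```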
Training Process
- Data Augmentation Techniques:
  - Implemented:
    - Image rotation.
    - Color jitter.
    - Random erasing.
    - Horizontal and vertical flips.
  - Designed to simulate diverse real-world conditions and improve model robustness.
- Hyperparameter Tuning:
  - Extensive tuning led to a 15% improvement in model precision and recall.
- Loss Functions:
  - Optimized loss functions tailored for classification tasks to enhance performance.
Validation and Overfitting
- Dataset Split:
  - Used a 70-30 split for training and validation.
  - Limitation: the absence of a distinct test set limits the ability to evaluate generalization on truly unseen data.
- Overfitting Strategies:
  - Implemented data augmentation techniques (e.g., random flips, color jittering).
  - Acknowledged the potential of advanced strategies such as dropout or weight decay.
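The 70/30 split and the weight-decay idea can be sketched with PyTorch built-ins. The tiny dummy dataset, the placeholder model, and the hyperparameter values below are assumptions for illustration only.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Dummy stand-in for the real 3,113-image dataset (20 tiny samples here).
full = TensorDataset(torch.randn(20, 3, 8, 8), torch.randint(0, 20, (20,)))

# 70/30 train/validation split, as described above.
n_train = int(0.7 * len(full))
train_set, val_set = random_split(full, [n_train, len(full) - n_train])

# Weight decay (L2 regularization), one of the strategies mentioned above,
# is a single optimizer argument in PyTorch. The model here is a placeholder.
model = torch.nn.Linear(8, 20)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
```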
Error Analysis
- Limited error analysis due to the lack of a test set.
- Recognized the importance of a comprehensive error analysis to:
  - Understand misclassifications and biases.
  - Identify patterns in errors.
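One concrete tool for the analysis described above is a confusion matrix, which exposes which species get mistaken for which. The predictions and labels below are dummy tensors standing in for real model output.

```python
import torch

def confusion_matrix(preds, labels, num_classes):
    """Count (true class, predicted class) pairs.
    Rows are true classes, columns are predicted classes."""
    cm = torch.zeros(num_classes, num_classes, dtype=torch.long)
    for t, p in zip(labels, preds):
        cm[t, p] += 1
    return cm

labels = torch.tensor([0, 1, 1, 2])   # dummy ground truth
preds  = torch.tensor([0, 2, 1, 2])   # dummy model predictions
cm = confusion_matrix(preds, labels, num_classes=3)
print(cm[1])  # tensor([0, 1, 1]): class 1 was right once, confused with class 2 once
```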
Future Improvements
- Incorporate a Test Set:
  - Evaluate performance on completely unseen data for a truer measure of generalization.
- Conduct Detailed Error Analysis:
  - Hypothesize causes of errors and refine the model to enhance robustness.
- Explore Advanced Techniques:
  - Introduce regularization and further data augmentation to mitigate overfitting.