I wanted to understand the whole image-classification workflow from scratch. Instead of starting with the usual MNIST dataset, I wanted to try something different. The retina dataset interested me, and I wanted to understand how models can be trained on it using transfer learning.
You can download the retina dataset from Kaggle (it is the dataset of the APTOS 2019 Blindness Detection competition).
import lightning as L
from torchvision import models
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
import pandas as pd
import os
from PIL import Image
from torchvision.transforms import transforms
from sklearn.model_selection import train_test_split
import torch
# Define the dataset
class RetinaDataset(Dataset):
    """Map-style dataset of retina images paired with integer labels.

    The CSV is read positionally: column 0 holds the image id, and the image
    file ``<img_dir>/<id>.png`` is loaded for it; column 1 holds the label.
    """

    def __init__(self, csv_file, img_dir, transform=None):
        """
        Args:
            csv_file: Path to a CSV with image ids in column 0 and labels in column 1.
            img_dir: Directory containing the ``<id>.png`` image files.
            transform: Optional callable applied to each loaded PIL image.
        """
        super().__init__()
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of samples (rows in the CSV)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as ``(image, label)``.

        The image is converted to 3-channel RGB and, if a transform was given,
        passed through it; the label is returned as a Python ``int``.
        """
        img_path = os.path.join(self.img_dir, f"{self.data.iloc[idx, 0]}.png")
        # Close the file deterministically; convert() returns a new, fully
        # loaded image, so the source file is no longer needed afterwards.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        label = int(self.data.iloc[idx, 1])
        if self.transform is not None:
            img = self.transform(img)
        return img, label
# Define the transforms.
# The backbone below is loaded with ImageNet-pretrained weights and kept frozen,
# so inputs are normalized with the ImageNet per-channel statistics those
# weights were trained with (rather than a generic 0.5 mean/std).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # fixed square input size
    transforms.ToTensor(),          # PIL image -> float tensor in [0, 1]
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
# Load dataset and dataloader.
# The data location can be overridden with the APTOS_DATA_DIR environment
# variable; the default is the original local path.
DATA_DIR = os.environ.get(
    "APTOS_DATA_DIR", "/home/aarabhi/learning/aptos2019-blindness-detection"
)
train_dataset = RetinaDataset(
    csv_file=os.path.join(DATA_DIR, "train.csv"),
    img_dir=os.path.join(DATA_DIR, "train_images"),
    transform=transform,
)
# Split dataset - train and val (80/20).
# Stratify on the label column (column 1, as read by RetinaDataset) so both
# splits keep the same class proportions.
split_labels = train_dataset.data.iloc[:, 1]
train_idx, val_idx = train_test_split(
    list(range(len(train_dataset))),
    test_size=0.2,
    random_state=42,
    stratify=split_labels,
)
# Create dataloaders; both share the same loader settings, only shuffling differs.
loader_kwargs = dict(batch_size=32, num_workers=4, pin_memory=True, persistent_workers=True)
train_dataloader = DataLoader(Subset(train_dataset, train_idx), shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(Subset(train_dataset, val_idx), shuffle=False, **loader_kwargs)
# Load ResNet-18 with ImageNet weights. ``weights=`` replaces the deprecated
# ``pretrained=True`` flag; IMAGENET1K_V1 is the checkpoint that flag loaded.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# Freeze all pretrained layers so only the new head below is trained.
model.requires_grad_(False)
# We modify the model to classify for 5 outputs instead of 1000 from ImageNet.
# The new layer is created after freezing, so its parameters stay trainable.
num_features = model.fc.in_features
num_classes = 5
model.fc = nn.Linear(num_features, num_classes)
class RetinaClassifier(L.LightningModule):
    """LightningModule that trains a wrapped classifier with cross-entropy loss.

    Only the wrapped model's parameters with ``requires_grad=True`` are
    optimized.

    NOTE(review): requires_grad=False does not stop BatchNorm layers from
    updating their running statistics while the module is in training mode;
    if a frozen backbone contains BatchNorm, consider keeping it in eval mode.
    """

    def __init__(self, model):
        """
        Args:
            model: ``nn.Module`` mapping a batch of images to class logits.
        """
        super().__init__()
        self.model = model
        self.loss_criterion = nn.CrossEntropyLoss()

    def forward(self, images):
        """Return class logits for a batch of images."""
        return self.model(images)

    def _shared_step(self, batch):
        """Compute the cross-entropy loss for one ``(images, labels)`` batch."""
        images, labels = batch
        outputs = self(images)
        return self.loss_criterion(outputs, labels)

    def training_step(self, batch, batch_idx=0):
        """Log and return the training loss for one batch."""
        loss = self._shared_step(batch)
        self.log("Train loss", loss)
        return loss

    def validation_step(self, batch, batch_idx=0):
        """Log the validation loss for one batch."""
        loss = self._shared_step(batch)
        self.log("Val loss", loss)

    def configure_optimizers(self):
        """Return an Adam optimizer over this module's trainable parameters."""
        # Use this module's own parameters (not a module-level global) so the
        # class works with whatever model it wraps; frozen parameters are skipped.
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable, lr=0.001)
# Launch training only when run as a script. DataLoader workers started with the
# "spawn" method (the default on Windows/macOS) re-import the main module, and
# this guard keeps them from starting training again.
if __name__ == "__main__":
    model_pl = RetinaClassifier(model)
    trainer = L.Trainer(max_epochs=20)
    trainer.fit(model_pl, train_dataloader, val_dataloader)