Table of Contents

Classify retina images for diabetic retinopathy

I wanted to understand the whole image-classification workflow from scratch. Instead of starting with the common MNIST dataset, I wanted to try something different. The retina dataset interested me, and I wanted to understand how models can be trained using transfer learning.

You can download the retina dataset from the APTOS 2019 Blindness Detection competition on Kaggle.

import lightning as L
from torchvision import models
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
import pandas as pd
import os
from PIL import Image
from torchvision.transforms import transforms
from sklearn.model_selection import train_test_split
import torch 
# Define the dataset
class RetinaDataset(Dataset):
    """Image/label dataset backed by a CSV index file and an image directory.

    Row ``idx`` of the CSV supplies the image id in its first column (the file
    ``<img_dir>/<id>.png`` is loaded) and the integer label in its second column.

    Args:
        csv_file: path to the CSV file, read with ``pd.read_csv``.
        img_dir: directory containing the ``.png`` images.
        transform: optional callable applied to each RGB PIL image.
    """

    def __init__(self, csv_file, img_dir, transform=None):
        super().__init__()
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx`` of the CSV."""
        img_name = os.path.join(self.img_dir, f"{self.data.iloc[idx, 0]}.png")
        # Open inside a context manager so the file handle is released promptly;
        # convert() returns an independent, fully loaded image.
        with Image.open(img_name) as raw:
            img = raw.convert('RGB')
        label = int(self.data.iloc[idx, 1])

        # Compare against None explicitly: any provided callable should run.
        if self.transform is not None:
            img = self.transform(img)

        return img, label


# Define the transforms.
# The ResNet-18 backbone is pretrained on ImageNet, so inputs are normalized with
# the ImageNet channel statistics it was trained with rather than 0.5/0.5.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Load the dataset. The data directory defaults to the original location but can
# be overridden with the APTOS_DATA_DIR environment variable.
DATA_DIR = os.environ.get(
    "APTOS_DATA_DIR", "/home/aarabhi/learning/aptos2019-blindness-detection"
)
train_dataset = RetinaDataset(csv_file=os.path.join(DATA_DIR, "train.csv"),
                              img_dir=os.path.join(DATA_DIR, "train_images"),
                              transform=transform)


# Split the dataset indices into train (80%) and validation (20%) sets.
# Stratify on the label column (second CSV column) so every class keeps the same
# proportion in both splits; this matters when some classes are rare.
train_idx, val_idx = train_test_split(
    list(range(len(train_dataset))),
    test_size=0.2,
    random_state=42,
    stratify=train_dataset.data.iloc[:, 1],
)

# Create the dataloaders. Both share batching/worker settings; only the
# training subset is shuffled.
_loader_kwargs = dict(
    batch_size=32,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
)
train_dataloader = DataLoader(Subset(train_dataset, train_idx), shuffle=True, **_loader_kwargs)
val_dataloader = DataLoader(Subset(train_dataset, val_idx), shuffle=False, **_loader_kwargs)


# Load ResNet-18 with ImageNet weights. `weights=` replaces the deprecated
# `pretrained=True` flag; IMAGENET1K_V1 is the checkpoint `pretrained=True` loaded.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Freeze all pretrained layers so they receive no gradient updates.
for p in model.parameters():
    p.requires_grad = False

# Replace the 1000-class ImageNet head with a 5-class head. The new layer is
# created after freezing, so its parameters keep requires_grad=True.
num_features = model.fc.in_features
num_classes = 5
model.fc = nn.Linear(num_features, num_classes)


class RetinaClassifier(L.LightningModule):
    """LightningModule that trains a wrapped network as a multi-class classifier.

    Args:
        model: network mapping a batch of images to class logits.
        lr: learning rate for the Adam optimizer (default 0.001).
    """

    def __init__(self, model, lr=0.001):
        super().__init__()
        self.model = model
        self.lr = lr
        self.loss_criterion = nn.CrossEntropyLoss()

    def forward(self, images):
        """Return the wrapped model's logits for a batch of images."""
        return self.model(images)

    def training_step(self, batch, batch_idx=None):
        """Compute and log the cross-entropy loss for one training batch."""
        images, labels = batch
        outputs = self(images)
        loss = self.loss_criterion(outputs, labels)
        self.log("Train loss", loss)
        return loss

    def validation_step(self, batch, batch_idx=None):
        """Log cross-entropy loss and top-1 accuracy for one validation batch."""
        images, labels = batch
        outputs = self(images)
        loss = self.loss_criterion(outputs, labels)
        acc = (outputs.argmax(dim=1) == labels).float().mean()
        self.log("Val loss", loss)
        self.log("Val acc", acc)

    def configure_optimizers(self):
        # Optimize only parameters that are still trainable (e.g. the new head
        # when the backbone is frozen). Read them from this module rather than a
        # module-level global, so the class works with any wrapped model.
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable, lr=self.lr)




# Wrap the model and train. The __main__ guard stops DataLoader worker processes
# (num_workers=4) from re-running training when they re-import this script under
# the "spawn" multiprocessing start method (the default on Windows and macOS).
if __name__ == "__main__":
    model_pl = RetinaClassifier(model)

    trainer = L.Trainer(max_epochs=20)
    trainer.fit(model_pl, train_dataloader, val_dataloader)