Note

If running in Colab, consider changing the runtime type before starting, in order to have access to GPU resources: Runtime -> Change Runtime Type, then choose GPU as hardware accelerator.

0. Classification of Alzheimer’s disease diagnosis

The goal of this lab session is to train a network that will perform a binary classification between control participants and patients affected by Alzheimer's disease. The input of the network is a neuroimaging modality: the T1-weighted MRI. In this project we use the PyTorch library.

import torch
import numpy as np
import pandas as pd
from torch import nn
from time import time
from os import path
from torchvision import transforms
import random
from copy import deepcopy

Database

In this session we use the images from a public research project: OASIS-1. Two labels exist in this dataset:

  • CN (Cognitively Normal) for healthy participants.

  • AD (Alzheimer’s Disease) for patients affected by Alzheimer’s disease.

The original images were preprocessed using Clinica: a software platform for clinical neuroimaging studies. Preprocessed images and other files are distributed in a tarball; run the following commands to download and extract them.

! wget --no-check-certificate --show-progress https://aramislab.paris.inria.fr/files/data/databases/DL4MI/OASIS-1-dataset_pt_new.tar.gz
--2022-01-20 18:25:26--  https://aramislab.paris.inria.fr/files/data/databases/DL4MI/OASIS-1-dataset_pt_new.tar.gz
Resolving aramislab.paris.inria.fr (aramislab.paris.inria.fr)... 128.93.101.229
Connecting to aramislab.paris.inria.fr (aramislab.paris.inria.fr)|128.93.101.229|:443... connected.
WARNING: cannot verify aramislab.paris.inria.fr's certificate, issued by 'CN=TERENA SSL CA 3,O=TERENA,L=Amsterdam,ST=Noord-Holland,C=NL':
  Unable to locally verify the issuer's authority.
HTTP request sent, awaiting response... 200 OK
Length: 1387416064 (1.3G) [application/octet-stream]
Saving to: 'OASIS-1-dataset_pt_new.tar.gz'

OASIS-1-dataset_pt_ 100%[===================>]   1.29G  91.4MB/s    in 15s

2022-01-20 18:25:41 (86.1 MB/s) - 'OASIS-1-dataset_pt_new.tar.gz' saved [1387416064/1387416064]
! tar xf OASIS-1-dataset_pt_new.tar.gz -C ./

One crucial step before training a neural network is to check the dataset. Are the classes balanced? Are there biases in the dataset that may differentiate the labels?

Here we will focus on the demographics (age, sex and level of education) and two cognitive scores:

  • The MMS (Mini Mental State), rated from 0 (no correct answer) to 30 (healthy subject).

  • The CDR (Clinical Dementia Rating), which is 0 if the participant is non-demented, and 0.5, 1, 2 and 3 for very mild, mild, moderate and severe dementia, respectively.

Let’s explore the data:

# Load the complete dataset
OASIS_df = pd.read_csv(
    'OASIS-1_dataset/tsv_files/lab_1/OASIS_BIDS.tsv', sep='\t',
    usecols=['participant_id', 'session_id', 'alternative_id_1', 'sex',
             'education_level', 'age_bl', 'diagnosis_bl', 'laterality', 'MMS',
             'cdr_global', 'diagnosis']
)
# Show first items of the table
print(OASIS_df.head())
# First visual inspection
_ = OASIS_df.hist(figsize=(16, 8))
   participant_id session_id alternative_id_1 sex  education_level  age_bl  \
0  sub-OASIS10001    ses-M00    OAS1_0001_MR1   F              2.0      74   
1  sub-OASIS10002    ses-M00    OAS1_0002_MR1   F              4.0      55   
2  sub-OASIS10003    ses-M00    OAS1_0003_MR1   F              4.0      73   
3  sub-OASIS10004    ses-M00    OAS1_0004_MR1   M              NaN      28   
4  sub-OASIS10005    ses-M00    OAS1_0005_MR1   M              NaN      18   

  diagnosis_bl laterality   MMS  cdr_global diagnosis  
0           CN          R  29.0         0.0        CN  
1           CN          R  29.0         0.0        CN  
2           AD          R  27.0         0.5        AD  
3           CN          R  30.0         NaN        CN  
4           CN          R  30.0         NaN        CN  
../_images/classification_7_1.png

These graphics give an overview of the distribution of the numerical variables. For example, the education level is well distributed among the participants of the study. Also, most of the subjects are young (around 20 years old) and healthy (MMS score equal to 30 and CDR score equal to 0).

The next cell creates (and runs) a function, characteristics_table, that summarizes the main characteristics of the population in the dataset. We will use it again later.

# Study the characteristics of the AD & CN populations (age, sex, MMS, cdr_global)
def characteristics_table(df, merged_df):
    """Creates a DataFrame that summarizes the characteristics of the DataFrame df"""
    diagnoses = np.unique(df.diagnosis.values)
    population_df = pd.DataFrame(index=diagnoses,
                                columns=['N', 'age', '%sexF', 'education',
                                         'MMS', 'CDR=0', 'CDR=0.5', 'CDR=1', 'CDR=2'])
    merged_df = merged_df.set_index(['participant_id', 'session_id'], drop=True)
    df = df.set_index(['participant_id', 'session_id'], drop=True)
    sub_merged_df = merged_df.loc[df.index]

    for diagnosis in population_df.index.values:
        diagnosis_df = sub_merged_df[df.diagnosis == diagnosis]
        population_df.loc[diagnosis, 'N'] = len(diagnosis_df)
        # Age
        mean_age = np.mean(diagnosis_df.age_bl)
        std_age = np.std(diagnosis_df.age_bl)
        population_df.loc[diagnosis, 'age'] = '%.1f ± %.1f' % (mean_age, std_age)
        # Sex
        population_df.loc[diagnosis, '%sexF'] = round((len(diagnosis_df[diagnosis_df.sex == 'F']) / len(diagnosis_df)) * 100, 1)
        # Education level
        mean_education_level = np.nanmean(diagnosis_df.education_level)
        std_education_level = np.nanstd(diagnosis_df.education_level)
        population_df.loc[diagnosis, 'education'] = '%.1f ± %.1f' % (mean_education_level, std_education_level)
        # MMS
        mean_MMS = np.mean(diagnosis_df.MMS)
        std_MMS = np.std(diagnosis_df.MMS)
        population_df.loc[diagnosis, 'MMS'] = '%.1f ± %.1f' % (mean_MMS, std_MMS)
        # CDR
        for value in ['0', '0.5', '1', '2']:
            population_df.loc[diagnosis, 'CDR=%s' % value] = len(diagnosis_df[diagnosis_df.cdr_global == float(value)])

    return population_df

population_df = characteristics_table(OASIS_df, OASIS_df)
population_df
       N          age %sexF  education         MMS CDR=0 CDR=0.5 CDR=1 CDR=2
AD    73   77.5 ± 7.4  63.0  2.7 ± 1.3  22.7 ± 3.6     0      45    26     2
CN   304  44.0 ± 23.3  62.2  3.5 ± 1.2  29.7 ± 0.6   124       0     0     0

Preprocessing

Theoretically, one of the main advantages of deep learning methods is their ability to work without extensive data preprocessing. However, as we have only a few images to train the network in this lab session, the preprocessing here is very extensive. More specifically, the images underwent:

  1. Non-linear registration.

  2. Segmentation of grey matter.

  3. Conversion to tensor format (.pt).

As mentioned above, to obtain the preprocessed images, we used some pipelines provided by Clinica and ClinicaDL in order to:

  1. Convert the original dataset to BIDS format (clinica convert oasis-to-bids).

  2. Perform the non-linear registration and segmentation of grey matter (pipeline t1-volume).

  3. Obtain the preprocessed images in tensor format (tensor extraction using ClinicaDL, clinicadl extract).

The preprocessed images are stored in the CAPS folder structure and all have the same size (121x145x121). You will find below a class called MRIDataset which allows easy browsing of the database.

from torch.utils.data import Dataset, DataLoader, sampler
from os import path

class MRIDataset(Dataset):

    def __init__(self, img_dir, data_df, transform=None):
        """
        Args:
            img_dir (str): path to the CAPS directory containing preprocessed images
            data_df (DataFrame): metadata of the population.
                Columns must include participant_id, session_id and diagnosis.
            transform (callable): transform applied on-the-fly to the images
                (possibly several transforms chained with torchvision.transforms.Compose).
        """
        self.img_dir = img_dir
        self.transform = transform
        self.data_df = data_df
        self.label_code = {"AD": 1, "CN": 0}

        self.size = self[0]['image'].shape

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, idx):

        diagnosis = self.data_df.loc[idx, 'diagnosis']
        label = self.label_code[diagnosis]

        participant_id = self.data_df.loc[idx, 'participant_id']
        session_id = self.data_df.loc[idx, 'session_id']
        filename = 'subjects/' + participant_id + '/' + session_id + '/' + \
          'deeplearning_prepare_data/image_based/custom/' + \
          participant_id + '_' + session_id + \
          '_T1w_segm-graymatter_space-Ixi549Space_modulated-off_probability.pt'

        image = torch.load(path.join(self.img_dir, filename))

        if self.transform:
            image = self.transform(image)

        sample = {'image': image, 'label': label,
                  'participant_id': participant_id,
                  'session_id': session_id}
        return sample

    def train(self):
        self.transform.train()

    def eval(self):
        self.transform.eval()

To facilitate the training and avoid overfitting due to the limited amount of data, the network won't use the full image but only a part of it (of size 30x40x30) centered on a specific neuroanatomical region: the hippocampus (HC). This structure is known to be linked to memory, and is atrophied in the majority of patients with Alzheimer's disease.

To improve the training and reduce overfitting, a random shift was added to the cropping function. This means that the bounding box around the hippocampus may be shifted by a limited number of voxels in each of the three directions.

class CropLeftHC(object):
    """Crops the left hippocampus of a MRI non-linearly registered to MNI"""
    def __init__(self, random_shift=0):
        self.random_shift = random_shift
        self.train_mode = True
    def __call__(self, img):
        if self.train_mode:
            x = random.randint(-self.random_shift, self.random_shift)
            y = random.randint(-self.random_shift, self.random_shift)
            z = random.randint(-self.random_shift, self.random_shift)
        else:
            x, y, z = 0, 0, 0
        return img[:, 25 + x:55 + x,
                   50 + y:90 + y,
                   27 + z:57 + z].clone()

    def train(self):
        self.train_mode = True

    def eval(self):
        self.train_mode = False

class CropRightHC(object):
    """Crops the right hippocampus of a MRI non-linearly registered to MNI"""
    def __init__(self, random_shift=0):
        self.random_shift = random_shift
        self.train_mode = True
    def __call__(self, img):
        if self.train_mode:
            x = random.randint(-self.random_shift, self.random_shift)
            y = random.randint(-self.random_shift, self.random_shift)
            z = random.randint(-self.random_shift, self.random_shift)
        else:
            x, y, z = 0, 0, 0
        return img[:, 65 + x:95 + x,
                   50 + y:90 + y,
                   27 + z:57 + z].clone()

    def train(self):
        self.train_mode = True

    def eval(self):
        self.train_mode = False
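
As a quick sanity check, whatever the value of the random shift, the crop always returns a patch of size 1x30x40x30. A minimal sketch, assuming a dummy preprocessed image of size 1x121x145x121 (the variable names here are illustrative):

# Sketch: the crop size is fixed whatever the random shift
crop = CropLeftHC(random_shift=2)
dummy_img = torch.rand(1, 121, 145, 121)
print(crop(dummy_img).shape)  # torch.Size([1, 30, 40, 30]), randomly shifted crop
crop.eval()                   # deterministic, centered crop
print(crop(dummy_img).shape)  # torch.Size([1, 30, 40, 30])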

Visualization

Here we visualize the raw, preprocessed and cropped data.

import matplotlib.pyplot as plt
import nibabel as nib
from scipy.ndimage import rotate

subject = 'sub-OASIS10003'
preprocessed_pt = torch.load(f'OASIS-1_dataset/CAPS/subjects/{subject}/ses-M00/' +
                    f'deeplearning_prepare_data/image_based/custom/{subject}_ses-M00_' +
                    'T1w_segm-graymatter_space-Ixi549Space_modulated-off_' +
                    'probability.pt')
raw_nii = nib.load(f'OASIS-1_dataset/raw/{subject}_ses-M00_T1w.nii.gz')

raw_np = raw_nii.get_fdata()

def show_slices(slices):
    """ Function to display a row of image slices """
    fig, axes = plt.subplots(1, len(slices))
    for i, slice in enumerate(slices):
        axes[i].imshow(slice.T, cmap="gray", origin="lower")

slice_0 = raw_np[:, :, 78]
slice_1 = raw_np[122, :, :]
slice_2 = raw_np[:, 173, :]
show_slices([slice_0, rotate(slice_1, 90), rotate(slice_2, 90)])
plt.suptitle(f'Slices of raw image of subject {subject}')
plt.show()

slice_0 = preprocessed_pt[0, 60, :, :]
slice_1 = preprocessed_pt[0, :, 72, :]
slice_2 = preprocessed_pt[0, :, :, 60]
show_slices([slice_0, slice_1, slice_2])
plt.suptitle(f'Center slices of preprocessed image of subject {subject}')
plt.show()

leftHC_pt = CropLeftHC()(preprocessed_pt)
slice_0 = leftHC_pt[0, 15, :, :]
slice_1 = leftHC_pt[0, :, 20, :]
slice_2 = leftHC_pt[0, :, :, 15]
show_slices([slice_0, slice_1, slice_2])
plt.suptitle(f'Center slices of left HC of subject {subject}')
plt.show()
../_images/classification_15_0.png ../_images/classification_15_1.png ../_images/classification_15_2.png

1. Cross-validation

In order to choose the hyperparameters, the set of images is divided into a training set (80%) and a validation set (20%). The data split was performed so as to ensure a similar distribution of diagnosis, age and sex between the subjects of the training set and those of the validation set. Moreover, the MMS distribution of each class is preserved.

train_df = pd.read_csv('OASIS-1_dataset/tsv_files/lab_1/train.tsv', sep='\t')
valid_df = pd.read_csv('OASIS-1_dataset/tsv_files/lab_1/validation.tsv', sep='\t')

train_population_df = characteristics_table(train_df, OASIS_df)
valid_population_df = characteristics_table(valid_df, OASIS_df)

print(f"Train dataset:\n {train_population_df}\n")
print(f"Validation dataset:\n {valid_population_df}")
Train dataset:
       N          age %sexF  education         MMS CDR=0 CDR=0.5 CDR=1 CDR=2
AD   58   77.4 ± 7.5  69.0  2.8 ± 1.4  22.6 ± 3.6     0      37    19     2
CN  242  43.4 ± 23.5  62.0  3.6 ± 1.2  29.8 ± 0.5    97       0     0     0

Validation dataset:
      N          age %sexF  education         MMS CDR=0 CDR=0.5 CDR=1 CDR=2
AD  15   78.2 ± 6.6  40.0  2.5 ± 1.0  22.9 ± 3.6     0       8     7     0
CN  62  46.3 ± 22.6  62.9  3.4 ± 1.3  29.6 ± 0.7    27       0     0     0

2. Model

We propose here to design a convolutional neural network that takes as input a patch of size 30x40x30 centered on the left hippocampus. The architecture of the network was found using a random search on the architecture and optimization hyperparameters.

Reminder on CNN layers

In a CNN everything is called a layer, though the operations performed by the different layers are very different. You will find below a summary of the different operations that may be performed in a CNN.

Feature maps

The outputs of the layers in a convolutional network are called feature maps. Their size is written with the format:

n_channels @ dim1 x dim2 x dim3

For a 3D CNN the feature maps are actually 5D tensors, as the first dimension is the batch size. This dimension is added by the DataLoader of PyTorch, which stacks the 4D tensors computed by a Dataset.

img_dir = path.join('OASIS-1_dataset', 'CAPS')
batch_size = 4

example_dataset = MRIDataset(img_dir, OASIS_df, transform=CropLeftHC())
example_dataloader = DataLoader(example_dataset, batch_size=batch_size, drop_last=True)
for data in example_dataloader:
    pass

print(f"Shape of Dataset output:\n {example_dataset[0]['image'].shape}\n")

print(f"Shape of DataLoader output:\n {data['image'].shape}")
Shape of Dataset output:
 torch.Size([1, 30, 40, 30])

Shape of DataLoader output:
 torch.Size([4, 1, 30, 40, 30])

Convolutions (nn.Conv3d)

The main arguments of this layer are the number of input channels, the number of output channels (i.e. the number of filters trained) and the size of the filter (or kernel). If an integer k is given, the kernel will be a cube of size k. It is possible to construct rectangular kernels by entering a tuple (but this is very rare).

You will find below an illustration of how a single filter produces its output feature map by parsing the input feature map. The size of the output feature map depends on the convolution parameters and can be computed with the following formula:

\(O_i = \frac{I_i-k+2P}{S} + 1\)

  • \(O_i\) the size of the output along the ith dimension

  • \(I_i\) the size of the input along the ith dimension

  • \(k\) the size of the kernel

  • \(P\) the padding value

  • \(S\) the stride value

In the following example, \(\frac{5-3+2*0}{1}+1 = 3\).
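
This can be checked directly with PyTorch; a minimal sketch with a kernel of size 3, no padding and a stride of 1 on a 5x5x5 input:

# Sketch: verify the output size formula with nn.Conv3d (k=3, P=0, S=1)
conv = nn.Conv3d(1, 1, kernel_size=3)
x = torch.rand(1, 1, 5, 5, 5)
print(conv(x).shape)  # torch.Size([1, 1, 3, 3, 3]) as (5 - 3 + 2*0) / 1 + 1 = 3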

To be able to parse all the feature maps of the input, one filter is actually a 4D tensor of size (input_channels, k, k, k). The set of all the filters included in one convolutional layer is then a 5D tensor of size (output_channels, input_channels, k, k, k) stacking all the filters.

Each filter is also associated with one bias value, a scalar added to all the feature maps it produces. The bias is thus a 1D vector of size output_channels.

from torch import nn

conv_layer = nn.Conv3d(8, 16, 3)
print('Weights shape\n', conv_layer.weight.shape)
print()
print('Bias shape\n', conv_layer.bias.shape)
Weights shape
 torch.Size([16, 8, 3, 3, 3])

Bias shape
 torch.Size([16])

Batch Normalization (nn.BatchNorm3d)

Learns to normalize feature maps according to (Ioffe & Szegedy, 2015). The following formula is applied on each feature map \(FM_i\):

\(FM^{normalized}_i = \frac{FM_i - mean(FM_i)}{\sqrt{var(FM_i) + \epsilon}} * \gamma_i + \beta_i\)

  • \(\epsilon\) is a hyperparameter of the layer (default=1e-05)

  • \(\gamma_i\) is the value of the scale for the ith channel (learnable parameter)

  • \(\beta_i\) is the value of the shift for the ith channel (learnable parameter)

This layer does not behave the same way during training and evaluation; this is why the model must be put in evaluation mode in the test function with the command .eval().

batch_layer = nn.BatchNorm3d(16)
print('Gamma value\n', batch_layer.state_dict()['weight'].shape)
print()
print('Beta value\n', batch_layer.state_dict()['bias'].shape)
Gamma value
 torch.Size([16])

Beta value
 torch.Size([16])
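
The formula above can also be checked numerically. A minimal sketch that reproduces the output of nn.BatchNorm3d in training mode by computing the per-channel batch statistics by hand (the variable names are illustrative):

# Sketch: reproduce the BatchNorm computation manually (training mode)
bn = nn.BatchNorm3d(2)
x = torch.randn(4, 2, 3, 3, 3)
out = bn(x)

fm = x[:, 0]  # all values of the first channel across the batch
manual = (fm - fm.mean()) / torch.sqrt(fm.var(unbiased=False) + bn.eps)
manual = manual * bn.weight[0] + bn.bias[0]
print(torch.allclose(out[:, 0], manual, atol=1e-5))  # True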

Activation function (nn.LeakyReLU)

In order to introduce non-linearity in the model, an activation function is introduced after the convolutions. It is applied on all intensities independently.

The graph of the Leaky ReLU is displayed below, \(\alpha\) being a hyperparameter of the layer (default=0.01):

Leaky ReLU graph
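
In short, Leaky ReLU computes \(f(x) = x\) if \(x \geq 0\) and \(f(x) = \alpha x\) otherwise. A one-line sketch:

# Sketch: negative values are multiplied by alpha (default 0.01)
activation = nn.LeakyReLU()
x = torch.tensor([-2.0, -0.5, 0.0, 1.0])
print(activation(x))  # tensor([-0.0200, -0.0050,  0.0000,  1.0000])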

Pooling function (PadMaxPool3d)

The structure of the pooling layer is very similar to that of the convolutional layer: a kernel passes through the input with a defined size and stride. However, there are no learnable parameters in this layer; the kernel outputs the maximum value of the part of the feature map it covers.

Here is an example in 2D of the standard layer of PyTorch, nn.MaxPool2d:

nn.MaxPool2d behaviour

We can observe that the last column may not be used depending on the size of the kernel/input and stride value.
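
A minimal sketch of this loss of information: with a 5x5 input, a 2x2 kernel and a stride of 2, nn.MaxPool2d produces a 2x2 output, so the fifth row and column are never used:

# Sketch: the last row/column of an odd-sized input is ignored
pool = nn.MaxPool2d(kernel_size=2, stride=2)
x = torch.arange(25, dtype=torch.float32).view(1, 1, 5, 5)
print(pool(x).shape)  # torch.Size([1, 1, 2, 2]): floor((5 - 2) / 2) + 1 = 2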

This is why the custom module PadMaxPool3d was defined: it pads the input in order to exploit the information of the whole feature map.

class PadMaxPool3d(nn.Module):
    """A MaxPooling module which deals with odd sizes with padding"""
    def __init__(self, kernel_size, stride, return_indices=False, return_pad=False):
        super(PadMaxPool3d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.pool = nn.MaxPool3d(kernel_size, stride, return_indices=return_indices)
        self.pad = nn.ConstantPad3d(padding=0, value=0)
        self.return_indices = return_indices
        self.return_pad = return_pad

    def set_new_return(self, return_indices=True, return_pad=True):
        self.return_indices = return_indices
        self.return_pad = return_pad
        self.pool.return_indices = return_indices

    def forward(self, f_maps):
        coords = [self.stride - f_maps.size(i + 2) % self.stride for i in range(3)]
        for i, coord in enumerate(coords):
            if coord == self.stride:
                coords[i] = 0

        self.pad.padding = (coords[2], 0, coords[1], 0, coords[0], 0)

        if self.return_indices:
            output, indices = self.pool(self.pad(f_maps))

            if self.return_pad:
                return output, indices, (coords[2], 0, coords[1], 0, coords[0], 0)
            else:
                return output, indices

        else:
            output = self.pool(self.pad(f_maps))

            if self.return_pad:
                return output, (coords[2], 0, coords[1], 0, coords[0], 0)
            else:
                return output

Here is an illustration of the PadMaxPool3d behaviour: if the number of columns is odd, a column is added to avoid losing data.

Similarly, the formula to find the size of the output feature map is:

\(O_i = \left\lceil \frac{I_i-k+2P}{S} \right\rceil + 1\)
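
For example, with a 15x20x15 feature map, a kernel of 2 and a stride of 2, PadMaxPool3d pads the odd dimensions and outputs an 8x10x8 feature map. A quick sketch:

# Sketch: odd dimensions are padded so no information is lost
pad_pool = PadMaxPool3d(kernel_size=2, stride=2)
x = torch.rand(1, 1, 15, 20, 15)
print(pad_pool(x).shape)  # torch.Size([1, 1, 8, 10, 8]): ceil((15 - 2) / 2) + 1 = 8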

Dropout (nn.Dropout)

The aim of a dropout layer is to replace a fixed proportion of the input values by 0 during training only.

This layer does not behave the same way during training and evaluation; this is why the model must be put in evaluation mode in the test function with the command .eval().

dropout = nn.Dropout(0.5)
input_tensor = torch.rand(10)
output_tensor = dropout(input_tensor)
print("Input \n", input_tensor)
print()
print("Output \n", output_tensor)
Input 
 tensor([0.1311, 0.5546, 0.1405, 0.8767, 0.7725, 0.7001, 0.5156, 0.0886, 0.7068,
        0.1302])

Output 
 tensor([0.0000, 1.1091, 0.2810, 0.0000, 1.5450, 0.0000, 1.0311, 0.1773, 1.4136,
        0.0000])
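
Note that the surviving values above are exactly doubled: during training, PyTorch scales the kept values by \(\frac{1}{1-p}\) (inverted dropout, here \(p=0.5\)), so that the layer can simply be the identity at evaluation time. A minimal sketch, reusing the dropout layer and input_tensor defined above:

# Sketch: in evaluation mode, dropout is the identity
dropout.eval()
eval_output = dropout(input_tensor)
print(torch.equal(eval_output, input_tensor))  # True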

Fully-Connected Layers (nn.Linear)

The fully-connected layers take as input 2D tensors of size (batch_size, N). They have two mandatory arguments: the number of input values per sample and the number of output values per sample.

Each output neuron of a FC layer computes a linear combination of the inputs plus a bias.

fc = nn.Linear(16, 2)
print("Weights shape \n", fc.weight.shape)
print()
print("Bias shape \n", fc.bias.shape)
Weights shape 
 torch.Size([2, 16])

Bias shape 
 torch.Size([2])
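
This linear combination is simply \(x W^T + b\); a quick sketch reusing the fc layer defined above (the input tensor here is illustrative):

# Sketch: a Linear layer computes x @ W.T + b
x = torch.rand(4, 16)  # batch of 4 samples with 16 values each
print(torch.allclose(fc(x), x @ fc.weight.T + fc.bias))  # True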

TODO Network design

Construct here the network corresponding to the scheme and the following description:

Scheme of the network

The network includes 3 convolutional blocks, each composed of a convolutional layer (kernel size = 3, padding = 1, stride = 1), a batch normalization, a LeakyReLU activation and a MaxPooling layer. The 3 successive blocks include respectively 8, 16 and 32 filters.

Then, the feature maps array is flattened into a 1D vector to enter a fully-connected layer. Between the convolutional and the fully-connected layers, a dropout layer with a dropout rate of 0.5 is inserted.

# To complete
class CustomNetwork(nn.Module):

    def __init__(self):
        super(CustomNetwork, self).__init__()
        self.convolutions = nn.Sequential(
            nn.Conv3d(1, 8, 3, padding=1),
            # Size 8@30x40x30
            nn.BatchNorm3d(8),
            nn.LeakyReLU(),
            PadMaxPool3d(2, 2),
            # Size 8@15x20x15

            nn.Conv3d(8, 16, 3, padding=1),
            # Size 16@15x20x15
            nn.BatchNorm3d(16),
            nn.LeakyReLU(),
            PadMaxPool3d(2, 2),
            # Size 16@8x10x8

            nn.Conv3d(16, 32, 3, padding=1),
            # Size 32@8x10x8
            nn.BatchNorm3d(32),
            nn.LeakyReLU(),
            PadMaxPool3d(2, 2),
            # Size 32@4x5x4

        )

        self.linear = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(32 * 4 * 5 * 4, 2)

        )

    def forward(self, x):
        x = self.convolutions(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x
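
A quick way to check the architecture is to run a forward pass on a dummy batch and verify the output shape; a minimal sketch, on CPU, with an illustrative random batch:

# Sketch: forward pass on a random batch of 2 images
model = CustomNetwork()
dummy_batch = torch.rand(2, 1, 30, 40, 30)
print(model(dummy_batch).shape)  # torch.Size([2, 2]): one score per class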

3. Train & Test

Complete the train method in order to iteratively update the weights of the network.

Here the model leading to the lowest loss on the training set at the end of an epoch is returned; however, we could choose instead the model leading to the highest balanced accuracy, or the one obtained at the last iteration.

In many deep learning studies, the validation set is used during training to choose when the training should stop (early stopping) but also to retrieve the best model (model selection).

As we don't have any test set to evaluate the final selected model in an unbiased way, we chose not to use the validation set during training in order to limit its bias. However, you can choose to implement early stopping and/or model selection based on the validation set; just remember that even if your results on the validation set are better, that doesn't mean that this would be the case on an independent test set.

def train(model, train_loader, criterion, optimizer, n_epochs):
    """
    Method used to train a CNN

    Args:
        model: (nn.Module) the neural network
        train_loader: (DataLoader) a DataLoader wrapping a MRIDataset
        criterion: (nn.Module) a method to compute the loss of a mini-batch of images
        optimizer: (torch.optim) an optimization algorithm
        n_epochs: (int) number of epochs performed during training

    Returns:
        best_model: (nn.Module) the trained neural network
    """
    best_model = deepcopy(model)
    train_best_loss = np.inf

    for epoch in range(n_epochs):
        model.train()
        train_loader.dataset.train()
        for i, data in enumerate(train_loader, 0):
            # Retrieve mini-batch and put data on GPU with .cuda()
            images, labels = data['image'].cuda(), data['label'].cuda()
            # Forward pass
            outputs = model(images)
            # Loss computation
            loss = criterion(outputs, labels)
            # Back-propagation (gradients computation)
            loss.backward()
            # Parameters update
            optimizer.step()
            # Erase previous gradients
            optimizer.zero_grad()

        _, train_metrics = test(model, train_loader, criterion)

        print(
            f"Epoch {epoch}: loss = {train_metrics['mean_loss']:.4f}, "
            f"balanced accuracy = {train_metrics['balanced_accuracy']:.4f}"
            )

        if train_metrics['mean_loss'] < train_best_loss:
            best_model = deepcopy(model)
            train_best_loss = train_metrics['mean_loss']

    return best_model

def test(model, data_loader, criterion):
    """
    Method used to test a CNN

    Args:
        model: (nn.Module) the neural network
        data_loader: (DataLoader) a DataLoader wrapping a MRIDataset
        criterion: (nn.Module) a method to compute the loss of a mini-batch of images

    Returns:
        results_df: (DataFrame) the label predicted for every subject
        results_metrics: (dict) a set of metrics
    """
    model.eval()
    data_loader.dataset.eval()
    columns = ["participant_id", "proba0", "proba1",
               "true_label", "predicted_label"]
    results_df = pd.DataFrame(columns=columns)
    total_loss = 0

    with torch.no_grad():
        for i, data in enumerate(data_loader, 0):
            images, labels = data['image'].cuda(), data['label'].cuda()
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            probs = nn.Softmax(dim=1)(outputs)
            _, predicted = torch.max(outputs.data, 1)

            for idx, sub in enumerate(data['participant_id']):
                row = [sub,
                       probs[idx, 0].item(), probs[idx, 1].item(),
                       labels[idx].item(), predicted[idx].item()]
                row_df = pd.DataFrame([row], columns=columns)
                results_df = pd.concat([results_df, row_df])

    results_metrics = compute_metrics(results_df.true_label.values, results_df.predicted_label.values)
    results_df.reset_index(inplace=True, drop=True)
    results_metrics['mean_loss'] = total_loss / len(data_loader.dataset)

    return results_df, results_metrics


def compute_metrics(ground_truth, prediction):
    """Computes the accuracy, sensitivity, specificity and balanced accuracy"""
    tp = np.sum((prediction == 1) & (ground_truth == 1))
    tn = np.sum((prediction == 0) & (ground_truth == 0))
    fp = np.sum((prediction == 1) & (ground_truth == 0))
    fn = np.sum((prediction == 0) & (ground_truth == 1))

    metrics_dict = dict()
    metrics_dict['accuracy'] = (tp + tn) / (tp + tn + fp + fn)

    # Sensitivity
    if tp + fn != 0:
        metrics_dict['sensitivity'] = tp / (tp + fn)
    else:
        metrics_dict['sensitivity'] = 0.0

    # Specificity
    if fp + tn != 0:
        metrics_dict['specificity'] = tn / (fp + tn)
    else:
        metrics_dict['specificity'] = 0.0

    metrics_dict['balanced_accuracy'] = (metrics_dict['sensitivity'] + metrics_dict['specificity']) / 2

    return metrics_dict
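
A small sketch of compute_metrics on a toy example (the labels below are purely illustrative):

# Sketch: toy example with 2 positive and 3 negative ground-truth labels
y_true = np.array([1, 1, 0, 0, 0])
y_pred = np.array([1, 0, 0, 0, 1])
print(compute_metrics(y_true, y_pred))
# {'accuracy': 0.6, 'sensitivity': 0.5, 'specificity': 0.6666666666666666,
#  'balanced_accuracy': 0.5833333333333333}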

Train Classification with Left HC

Here we will train a first network that will learn to perform the binary classification AD vs CN on a cropped image around the left hippocampus.

All hyperparameters may have an influence, but one of the most influential is the learning rate, which can lead to poor convergence if it is too high or too low. Try different learning rates between \(10 ^{-5}\) and \(10 ^{-3}\) and observe the differences in loss evolution during training.

To increase the training speed you can also increase the batch size. But be careful: if the batch size becomes a non-negligible fraction of the training set, it may have a negative impact on loss convergence (Keskar et al, 2016).

Construction of dataset objects:

img_dir = path.join('OASIS-1_dataset', 'CAPS')
transform = CropLeftHC(2)

train_datasetLeftHC = MRIDataset(img_dir, train_df, transform=transform)
valid_datasetLeftHC = MRIDataset(img_dir, valid_df, transform=transform)

# Try different learning rates
learning_rate = 10**-4
n_epochs = 30
batch_size = 4

# Put the network on GPU
modelLeftHC = CustomNetwork().cuda()
train_loaderLeftHC = DataLoader(train_datasetLeftHC, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
# A high batch size improves test speed
valid_loaderLeftHC = DataLoader(valid_datasetLeftHC, batch_size=32, shuffle=False, num_workers=8, pin_memory=True)
criterion = nn.CrossEntropyLoss(reduction='sum')
optimizer = torch.optim.Adam(modelLeftHC.parameters(), learning_rate)

best_modelLeftHC = train(modelLeftHC, train_loaderLeftHC, criterion, optimizer, n_epochs)

valid_resultsLeftHC_df, valid_metricsLeftHC = test(best_modelLeftHC, valid_loaderLeftHC, criterion)
train_resultsLeftHC_df, train_metricsLeftHC = test(best_modelLeftHC, train_loaderLeftHC, criterion)
print(valid_metricsLeftHC)
print(train_metricsLeftHC)
Epoch 0: loss = 0.3441, balanced accuracy = 0.5259
Epoch 1: loss = 0.2748, balanced accuracy = 0.6966
Epoch 2: loss = 0.3429, balanced accuracy = 0.5410
Epoch 3: loss = 0.2412, balanced accuracy = 0.7228
Epoch 4: loss = 0.2342, balanced accuracy = 0.7490
Epoch 5: loss = 0.2253, balanced accuracy = 0.7900
Epoch 6: loss = 0.2225, balanced accuracy = 0.7642
Epoch 7: loss = 0.2727, balanced accuracy = 0.8887
Epoch 8: loss = 0.2144, balanced accuracy = 0.7728
Epoch 9: loss = 0.2088, balanced accuracy = 0.7880
Epoch 10: loss = 0.2322, balanced accuracy = 0.8818
Epoch 11: loss = 0.2088, balanced accuracy = 0.7814
Epoch 12: loss = 0.2208, balanced accuracy = 0.8814
Epoch 13: loss = 0.2137, balanced accuracy = 0.7724
Epoch 14: loss = 0.2268, balanced accuracy = 0.8794
Epoch 15: loss = 0.2042, balanced accuracy = 0.7790
Epoch 16: loss = 0.2111, balanced accuracy = 0.8942
Epoch 17: loss = 0.1904, balanced accuracy = 0.8118
Epoch 18: loss = 0.1884, balanced accuracy = 0.8159
Epoch 19: loss = 0.1905, balanced accuracy = 0.7942
Epoch 20: loss = 0.2029, balanced accuracy = 0.7614
Epoch 21: loss = 0.1819, balanced accuracy = 0.7942
Epoch 22: loss = 0.1795, balanced accuracy = 0.8159
Epoch 23: loss = 0.1930, balanced accuracy = 0.7807
Epoch 24: loss = 0.2110, balanced accuracy = 0.9159
Epoch 25: loss = 0.1759, balanced accuracy = 0.8331
Epoch 26: loss = 0.1744, balanced accuracy = 0.8697
Epoch 27: loss = 0.1858, balanced accuracy = 0.7893
Epoch 28: loss = 0.1685, balanced accuracy = 0.8307
Epoch 29: loss = 0.2021, balanced accuracy = 0.9352
{'accuracy': 0.8831168831168831, 'sensitivity': 0.6, 'specificity': 0.9516129032258065, 'balanced_accuracy': 0.7758064516129033, 'mean_loss': 0.30914231670367254}
{'accuracy': 0.9066666666666666, 'sensitivity': 0.7068965517241379, 'specificity': 0.9545454545454546, 'balanced_accuracy': 0.8307210031347962, 'mean_loss': 0.16847280997782946}

If you obtained a balanced accuracy of about 0.85 or more, there may be something wrong… Are you absolutely sure that your dataset is unbiased?

If you didn't remove the youngest subjects of OASIS, your dataset is biased, as the AD and CN participants do not have the same age distribution. In practice, people who come to the hospital for a diagnosis of Alzheimer's disease are all about the same age (50-90). No one has Alzheimer's disease at 20! You should then check that the performance of the network is still good on the older population only.

Check the accuracy on older participants (age ≥ 62, to match the minimum of the AD age distribution):

valid_resultsLeftHC_df = valid_resultsLeftHC_df.merge(OASIS_df, how='left', on='participant_id', sort=False)
valid_resultsLeftHC_old_df = valid_resultsLeftHC_df[(valid_resultsLeftHC_df.age_bl >= 62)]
compute_metrics(valid_resultsLeftHC_old_df.true_label, valid_resultsLeftHC_old_df.predicted_label)
{'accuracy': 0.71875,
 'sensitivity': 0.6,
 'specificity': 0.8235294117647058,
 'balanced_accuracy': 0.7117647058823529}

If the accuracy on older participants is very different from the one you obtained before, this could mean that your network is inefficient on the target population (people older than 60). You would then have to think again about your framework and possibly retrain your network…

Train Classification with Right HC

Another network can be trained on an image cropped around the right HC. The same hyperparameters as before may be reused.

Construction of dataset objects

transform = CropRightHC(2)

train_datasetRightHC = MRIDataset(img_dir, train_df, transform=transform)
valid_datasetRightHC = MRIDataset(img_dir, valid_df, transform=transform)

learning_rate = 10**-4
n_epochs = 30
batch_size = 4

# Put the network on GPU
modelRightHC = CustomNetwork().cuda()
train_loaderRightHC = DataLoader(train_datasetRightHC, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
valid_loaderRightHC = DataLoader(valid_datasetRightHC, batch_size=32, shuffle=False, num_workers=8, pin_memory=True)
criterion = nn.CrossEntropyLoss(reduction='sum')
optimizer = torch.optim.Adam(modelRightHC.parameters(), learning_rate)

best_modelRightHC = train(modelRightHC, train_loaderRightHC, criterion, optimizer, n_epochs)

valid_resultsRightHC_df, valid_metricsRightHC = test(best_modelRightHC, valid_loaderRightHC, criterion)
train_resultsRightHC_df, train_metricsRightHC = test(best_modelRightHC, train_loaderRightHC, criterion)
print(valid_metricsRightHC)
print(train_metricsRightHC)
Epoch 0: loss = 0.3113, balanced accuracy = 0.5517
Epoch 1: loss = 0.2573, balanced accuracy = 0.7683
Epoch 2: loss = 0.2494, balanced accuracy = 0.7228
Epoch 3: loss = 0.2357, balanced accuracy = 0.7424
Epoch 4: loss = 0.2302, balanced accuracy = 0.7531
Epoch 5: loss = 0.2250, balanced accuracy = 0.7662
Epoch 6: loss = 0.2557, balanced accuracy = 0.9032
Epoch 7: loss = 0.2136, balanced accuracy = 0.8097
Epoch 8: loss = 0.2138, balanced accuracy = 0.8356
Epoch 9: loss = 0.2458, balanced accuracy = 0.9032
Epoch 10: loss = 0.2129, balanced accuracy = 0.7662
Epoch 11: loss = 0.2698, balanced accuracy = 0.9080
Epoch 12: loss = 0.2073, balanced accuracy = 0.8700
Epoch 13: loss = 0.2030, balanced accuracy = 0.8031
Epoch 14: loss = 0.2302, balanced accuracy = 0.8946
Epoch 15: loss = 0.2261, balanced accuracy = 0.8946
Epoch 16: loss = 0.2265, balanced accuracy = 0.7528
Epoch 17: loss = 0.2078, balanced accuracy = 0.7680
Epoch 18: loss = 0.1867, balanced accuracy = 0.8352
Epoch 19: loss = 0.2024, balanced accuracy = 0.9135
Epoch 20: loss = 0.1891, balanced accuracy = 0.9111
Epoch 21: loss = 0.1808, balanced accuracy = 0.8373
Epoch 22: loss = 0.1827, balanced accuracy = 0.8873
Epoch 23: loss = 0.1852, balanced accuracy = 0.8983
Epoch 24: loss = 0.1799, balanced accuracy = 0.8221
Epoch 25: loss = 0.1748, balanced accuracy = 0.8287
Epoch 26: loss = 0.1837, balanced accuracy = 0.9221
Epoch 27: loss = 0.2252, balanced accuracy = 0.7569
Epoch 28: loss = 0.1756, balanced accuracy = 0.8483
Epoch 29: loss = 0.1927, balanced accuracy = 0.9114
{'accuracy': 0.9090909090909091, 'sensitivity': 0.6666666666666666, 'specificity': 0.967741935483871, 'balanced_accuracy': 0.8172043010752688, 'mean_loss': 0.2644238243629406}
{'accuracy': 0.9033333333333333, 'sensitivity': 0.7068965517241379, 'specificity': 0.9504132231404959, 'balanced_accuracy': 0.828654887432317, 'mean_loss': 0.17483099497389049}

Soft voting

To increase the accuracy of our system, the results of the two networks can be combined. Here we give both hippocampi the same weight.

def softvoting(leftHC_df, rightHC_df):
    df1 = leftHC_df.set_index('participant_id', drop=True)
    df2 = rightHC_df.set_index('participant_id', drop=True)
    results_df = pd.DataFrame(index=df1.index.values,
                              columns=['true_label', 'predicted_label',
                                       'proba0', 'proba1'])
    results_df.true_label = df1.true_label
    # Compute predicted label and probabilities
    results_df.proba1 = 0.5 * df1.proba1 + 0.5 * df2.proba1
    results_df.proba0 = 0.5 * df1.proba0 + 0.5 * df2.proba0
    results_df.predicted_label = (0.5 * df1.proba1 + 0.5 * df2.proba1 > 0.5).astype(int)

    return results_df

valid_results = softvoting(valid_resultsLeftHC_df, valid_resultsRightHC_df)
valid_metrics = compute_metrics(valid_results.true_label, valid_results.predicted_label)
print(valid_metrics)
{'accuracy': 0.8961038961038961, 'sensitivity': 0.6666666666666666, 'specificity': 0.9516129032258065, 'balanced_accuracy': 0.8091397849462365}

Keep in mind that the validation set was used to set the hyperparameters (learning rate, architecture), so the validation metrics are biased. To obtain unbiased results, the entire framework should be evaluated on an independent set (a test set).

4. Clustering on AD & CN populations

The classification results above were obtained in a supervised way: neurologists examined the participants of OASIS and gave a diagnosis depending on their clinical symptoms.

However, this label is often inaccurate (Beach et al, 2012). Hence, an unsupervised framework can be interesting to check what can be found in the data without being biased by a noisy label.

Model

A convenient architecture to extract features from an image with deep learning is the autoencoder (AE). This architecture is made of two parts:

  • the encoder, which learns to compress the image into a smaller vector, the code. It is composed of the same kind of operations as the convolutional part of the CNN seen before.

  • the decoder, which learns to reconstruct the original image from the code learnt by the encoder. It is composed of the transposed version of the operations used in the encoder.

You will find below CropMaxUnpool3d, the transposed version of PadMaxPool3d.

class CropMaxUnpool3d(nn.Module):
    def __init__(self, kernel_size, stride):
        super(CropMaxUnpool3d, self).__init__()
        self.unpool = nn.MaxUnpool3d(kernel_size, stride)

    def forward(self, f_maps, indices, padding=None):
        output = self.unpool(f_maps, indices)
        if padding is not None:
            x1 = padding[4]
            y1 = padding[2]
            z1 = padding[0]
            output = output[:, :, x1::, y1::, z1::]

        return output

To facilitate the reconstruction process, the pooling layers of the encoder return the positions of the maximum values. Hence the unpooling layer can replace the maximum values at the right place in the 2x2x2 sub-cube of the feature map. They also indicate if some zero padding was applied to the feature map, so that the unpooling layer can correctly crop its output feature map.
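A minimal sketch of this pool/unpool round trip: the indices and padding returned by PadMaxPool3d allow CropMaxUnpool3d to restore the original spatial size (the input tensor here is illustrative):

# Sketch: pooling then unpooling restores the original spatial size
pool = PadMaxPool3d(2, 2, return_indices=True, return_pad=True)
unpool = CropMaxUnpool3d(2, 2)
x = torch.rand(1, 1, 15, 20, 15)
pooled, indices, pad = pool(x)
restored = unpool(pooled, indices, pad)
print(pooled.shape, restored.shape)
# torch.Size([1, 1, 8, 10, 8]) torch.Size([1, 1, 15, 20, 15])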

class AutoEncoder(nn.Module):

    def __init__(self):
        super(AutoEncoder, self).__init__()

        # Initial size (30, 40, 30)

        self.encoder = nn.Sequential(
            nn.Conv3d(1, 8, 3, padding=1),
            nn.BatchNorm3d(8),
            nn.LeakyReLU(),
            PadMaxPool3d(2, 2, return_indices=True, return_pad=True),
            # Size (15, 20, 15)

            nn.Conv3d(8, 16, 3, padding=1),
            nn.BatchNorm3d(16),
            nn.LeakyReLU(),
            PadMaxPool3d(2, 2, return_indices=True, return_pad=True),
            # Size (8, 10, 8)

            nn.Conv3d(16, 32, 3, padding=1),
            nn.BatchNorm3d(32),
            nn.LeakyReLU(),
            PadMaxPool3d(2, 2, return_indices=True, return_pad=True),
            # Size (4, 5, 4)

            nn.Conv3d(32, 1, 1),
            # Size (4, 5, 4)
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(1, 32, 1),
            # Size (4, 5, 4)

            CropMaxUnpool3d(2, 2),
            nn.ConvTranspose3d(32, 16, 3, padding=1),
            nn.BatchNorm3d(16),
            nn.LeakyReLU(),
            # Size (8, 10, 8)

            CropMaxUnpool3d(2, 2),
            nn.ConvTranspose3d(16, 8, 3, padding=1),
            nn.BatchNorm3d(8),
            nn.LeakyReLU(),
            # Size (15, 20, 15)

            CropMaxUnpool3d(2, 2),
            nn.ConvTranspose3d(8, 1, 3, padding=1),
            nn.BatchNorm3d(1),
            nn.Sigmoid()
            # Size (30, 40, 30)
        )

    def forward(self, x):
        indices_list = []
        pad_list = []
        for layer in self.encoder:
            if isinstance(layer, PadMaxPool3d):
                x, indices, pad = layer(x)
                indices_list.append(indices)
                pad_list.append(pad)
            else:
                x = layer(x)

        code = x.view(x.size(0), -1)
        for layer in self.decoder:
            if isinstance(layer, CropMaxUnpool3d):
                x = layer(x, indices_list.pop(), pad_list.pop())
            else:
                x = layer(x)

        return code, x

Train Autoencoder

The training function of the autoencoder is very similar to that of the CNN. The main difference is that the loss is not computed by comparing the output with the diagnosis values using the cross-entropy loss, but by comparing it with the original image using, for example, the Mean Squared Error (MSE) loss.

def trainAE(model, train_loader, criterion, optimizer, n_epochs):
    """
    Method used to train an AutoEncoder

    Args:
        model: (nn.Module) the neural network
        train_loader: (DataLoader) a DataLoader wrapping a MRIDataset
        criterion: (nn.Module) a method to compute the loss of a mini-batch of images
        optimizer: (torch.optim) an optimization algorithm
        n_epochs: (int) number of epochs performed during training

    Returns:
        best_model: (nn.Module) the trained neural network.
    """
    best_model = deepcopy(model)
    train_best_loss = np.inf

    for epoch in range(n_epochs):
        model.train()
        train_loader.dataset.train()
        for i, data in enumerate(train_loader, 0):
            # ToDo
            # Complete the training function in a similar way
            # as for the CNN classification training.
            # Retrieve mini-batch
            images, labels = data['image'].cuda(), data['label'].cuda()
            # Forward pass + loss computation
            _, outputs = model(images)
            loss = criterion(outputs, images)
            # Back-propagation
            loss.backward()
            # Parameters update
            optimizer.step()
            # Erase previous gradients
            optimizer.zero_grad()

        mean_loss = testAE(model, train_loader, criterion)

        print(f'Epoch {epoch}: loss = {mean_loss:.6f}')

        if mean_loss < train_best_loss:
            best_model = deepcopy(model)
            train_best_loss = mean_loss

    return best_model


def testAE(model, data_loader, criterion):
    """
    Method used to test an AutoEncoder

    Args:
        model: (nn.Module) the neural network
        data_loader: (DataLoader) a DataLoader wrapping a MRIDataset
        criterion: (nn.Module) a method to compute the loss of a mini-batch of images

    Returns:
        mean_loss: (float) the mean loss per voxel of the images of the data_loader
    """
    model.eval()
    data_loader.dataset.eval()
    total_loss = 0

    with torch.no_grad():
        for i, data in enumerate(data_loader, 0):
            images, labels = data['image'].cuda(), data['label'].cuda()
            _, outputs = model(images)
            loss = criterion(outputs, images)
            total_loss += loss.item()

    return total_loss / len(data_loader.dataset) / np.prod(data_loader.dataset.size)
learning_rate = 10**-3
n_epochs = 30
batch_size = 4

AELeftHC = AutoEncoder().cuda()
criterion = nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(AELeftHC.parameters(), learning_rate)

best_AELeftHC = trainAE(AELeftHC, train_loaderLeftHC, criterion, optimizer, n_epochs)
Epoch 0: loss = 0.092463
Epoch 1: loss = 0.057393
Epoch 2: loss = 0.046466
Epoch 3: loss = 0.039317
Epoch 4: loss = 0.035276
Epoch 5: loss = 0.031526
Epoch 6: loss = 0.028902
Epoch 7: loss = 0.026990
Epoch 8: loss = 0.025812
Epoch 9: loss = 0.023425
Epoch 10: loss = 0.022873
Epoch 11: loss = 0.020765
Epoch 12: loss = 0.019893
Epoch 13: loss = 0.019150
Epoch 14: loss = 0.018286
Epoch 15: loss = 0.017654
Epoch 16: loss = 0.017210
Epoch 17: loss = 0.017415
Epoch 18: loss = 0.016208
Epoch 19: loss = 0.015972
Epoch 20: loss = 0.015389
Epoch 21: loss = 0.015128
Epoch 22: loss = 0.014811
Epoch 23: loss = 0.014832
Epoch 24: loss = 0.014468
Epoch 25: loss = 0.014108
Epoch 26: loss = 0.014059
Epoch 27: loss = 0.013753
Epoch 28: loss = 0.013487
Epoch 29: loss = 0.013449

Visualization

The simplest way to check if the AE training went well is to visualize the output and compare it to the original image seen by the autoencoder.

import matplotlib.pyplot as plt
import nibabel as nib
from scipy.ndimage import rotate

subject = 'sub-OASIS10003'
preprocessed_pt = torch.load(f'OASIS-1_dataset/CAPS/subjects/{subject}/ses-M00/' +
                    'deeplearning_prepare_data/image_based/custom/' + subject +
                    '_ses-M00_'+
                    'T1w_segm-graymatter_space-Ixi549Space_modulated-off_' +
                    'probability.pt')
input_pt = CropLeftHC()(preprocessed_pt).unsqueeze(0).cuda()
_, output_pt = best_AELeftHC(input_pt)


slice_0 = input_pt[0, 0, 15, :, :].cpu()
slice_1 = input_pt[0, 0, :, 20, :].cpu()
slice_2 = input_pt[0, 0, :, :, 15].cpu()
show_slices([slice_0, slice_1, slice_2])
plt.suptitle(f'Center slices of the input image of subject {subject}')
plt.show()

slice_0 = output_pt[0, 0, 15, :, :].cpu().detach()
slice_1 = output_pt[0, 0, :, 20, :].cpu().detach()
slice_2 = output_pt[0, 0, :, :, 15].cpu().detach()
show_slices([slice_0, slice_1, slice_2])
plt.suptitle(f'Center slices of the output image of subject {subject}')
plt.show()
../_images/classification_58_0.png ../_images/classification_58_1.png

Clustering

Now that the AE has extracted the most salient parts of the image into a smaller vector, the features obtained can be used for clustering.

Here we give an example with the Gaussian Mixture Model (GMM) of scikit-learn. To use it, we first need to concatenate the features and the labels of all the subjects in two matrices X and Y. This is what is done in the compute_dataset_features function.

def compute_dataset_features(data_loader, model):

    concat_codes = torch.Tensor().cuda()
    concat_labels = torch.LongTensor()
    concat_names = []

    for data in data_loader:
        image = data['image'].cuda()
        labels = data['label']
        names = data['participant_id']

        code, _ = model(image)
        concat_codes = torch.cat([concat_codes, code.squeeze(1)], 0)
        concat_labels = torch.cat([concat_labels, labels])
        concat_names = concat_names + names

    concat_codes_np = concat_codes.cpu().detach().numpy()
    concat_labels_np = concat_labels.numpy()
    concat_names = np.array(concat_names)[:, np.newaxis]

    return concat_codes_np, concat_labels_np, concat_names
# train_codes, train_labels, names = compute_dataset_features(train_loaderBothHC, best_AEBothHC)
train_codes, train_labels, names = compute_dataset_features(train_loaderLeftHC, best_AELeftHC)

Then the model will fit the training codes and build two clusters. The labels found in this unsupervised way can be compared to the true labels.

from sklearn import mixture
from sklearn.metrics import adjusted_rand_score

n_components = 2
model = mixture.GaussianMixture(n_components)
model.fit(train_codes)
train_predict = model.predict(train_codes)

metrics = compute_metrics(train_labels, train_predict)
ari = adjusted_rand_score(train_labels, train_predict)
print(f"Adjusted random index: {ari}")
Adjusted random index: 0.33869103831169145

The adjusted Rand index may not be very good; this could mean that the framework clustered another characteristic than the one you tried to target.

What is actually expected is that the clustering differentiation is made on the level of atrophy, which is mostly correlated with age but also with the disease stage (which we can model with the MMS score).

data_np = np.concatenate([names, train_codes,
                          train_labels[:, np.newaxis],
                          train_predict[:, np.newaxis]], axis=1)
columns = ['feature %i' % i for i in range(train_codes.shape[1])]
columns = ['participant_id'] + columns + ['true_label', 'predicted_label']
data_df = pd.DataFrame(data_np, columns=columns).set_index('participant_id')

merged_df = data_df.merge(OASIS_df.set_index('participant_id'), how='inner', on='participant_id')

plt.title('Clustering values according to age and MMS score')
for component in range(n_components):
    predict_df = merged_df[merged_df.predicted_label == str(component)]
    plt.plot(predict_df['age_bl'], predict_df['MMS'], 'o', label=f"cluster {component}")
plt.legend()
plt.xlabel('age')
plt.ylabel('MMS')
plt.show()
../_images/classification_65_0.png

You can try to improve this clustering by adding the codes obtained on the right hippocampus, performing further dimension reduction, or removing the age effect as in (Moradi et al, 2015).