Wednesday, 8 April 2020

Deep Learning Analysis of COVID-19 lung X-Rays using MATLAB: Part 1

UPDATE: see Part 5, where the grad-CAM results of Part 4 are used to train another suite of networks to help choose between all the lung X-ray classifiers presented in Part 3.

UPDATE: see Part 4, where I've performed a grad-CAM analysis on all the trained networks from Part 3, in the theme of Part 2.

UPDATE: see Part 3, where I've now compared the transfer learning performance of all 19 neural network types available in MATLAB R2020a on the lung X-ray analysis, i.e., extending beyond the googlenet network covered here.

UPDATE: see Part 2, where I have now performed a Class Activation Map study to successfully counter the caveats in this post, namely my concern that the trained networks might be exploiting extraneous artefacts embedded in the X-ray images (e.g., text) to exaggerate their predictive performance.

*** DISCLAIMER ***

I have no medical training. Nothing presented here should be considered in any way as informative from a medical point-of-view. This is simply an exercise in image analysis via Deep Learning using MATLAB, with lung X-rays as a topical example in these times of COVID-19. 

INTRODUCTION

Recent examples of lung X-ray image classification via deep learning have used TensorFlow. The most comprehensive approach I could find so far is described here, with further detail here (I'll refer to it as COVID-Net). I wanted to try something similar in MATLAB, since that is my tool of choice in my day job, where I use it for various Artificial Intelligence / Machine Learning investigations, and also for side-projects such as aviation weather forecasting.

APPROACH

My goal was to use the underlying chest X-ray image dataset from COVID-Net to train a deep neural network via the technique of transfer learning, just to see how well the resulting classifiers would perform. All analysis is performed using MATLAB, and code snippets are provided which may be useful to others.

SAMPLE IMAGES

Before getting started with the analysis, here are some sample images from the dataset used for training and testing, to give an idea of how challenging it is for the neural networks to distinguish between the various classes.

Healthy

Bacterial Pneumonia
Viral Pneumonia (not COVID-19)
COVID-19 Pneumonia


EXAMPLE 1: "YES / NO" Classification of Pneumonia

Data Preparation

This first example addresses the task of training the network to classify whether a given X-ray belongs to a normal (healthy) patient or one suffering from pneumonia, irrespective of the type of pneumonia (i.e., bacterial, viral, etc.).

The dataset has 752 normal images and 4558 with pneumonia (across all types). To provide a balanced set across the two target classes ("yes" for pneumonia and "no" for healthy), I used only 752 images from each class: all of the "no" class, and a random selection from the "yes" class. From each class, I used a random 85% or so (640 images) for training and the remaining 15% (112 images) for validation. The line of code which does this in MATLAB is:


[trainingImages,validationImages,holdoutImages] = splitEachLabel(images,640,112,'randomized'); 

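where images refers to an imageDatastore object initialised on a master folder of images sorted into two subfolders, one for each class (YES and NO). Any images beyond the 640 + 112 per class (i.e., the surplus "yes" images) end up in holdoutImages, which is not used here. The datastore is created via the following line of code: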


images = imageDatastore(sortedPath,'IncludeSubfolders',true,'LabelSource','foldernames');


where sortedPath is the variable containing the name of the master folder. Note: I performed the sorting offline, based simply on whether the substring "NORMAL" appears in the filename (defining the "NO" class), and assumed that all other filenames belong to the "YES" class. Here's a code snippet showing how to do this in MATLAB:


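% Datastore over the unsorted training images (top-level folder only)
allImages = imageDatastore('\covid\data\train','IncludeSubfolders',false);

% Destination folders for the two classes (created if they do not exist)
yesPath = '\covid\sorted\yesno\yes\';
noPath  = '\covid\sorted\yesno\no\';
if ~exist(yesPath,'dir'), mkdir(yesPath); end
if ~exist(noPath,'dir'),  mkdir(noPath);  end

for i = 1:numel(allImages.Files)
    srcFile = allImages.Files{i};
    [~,name,ext] = fileparts(srcFile);
    % Filenames containing 'NORMAL' are healthy ("no"); all others are pneumonia ("yes")
    if contains(name,'NORMAL')
        destFolder = noPath;
    else
        destFolder = yesPath;
    end
    copyfile(srcFile,fullfile(destFolder,[name ext]));
end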
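The per-class counts passed to splitEachLabel in this post (640/112 here, and 1717/302 and 65/11 in the later examples) all appear consistent with one simple rule: take 15% of the class size, rounded down, for validation, and give the remainder to training. A minimal base-MATLAB sketch of that arithmetic, assuming this rule (it is my reconstruction, not code from the original analysis):

```matlab
% Images used per class in Example 1, Example 2, and Examples 3/4
% (after balancing each class down to the size of the smallest class)
imagesPerClass = [752 2019 76];

% Assumed rule: 15% (rounded down) for validation, the remainder for training
numValidation = floor(0.15 * imagesPerClass);
numTraining   = imagesPerClass - numValidation;

% Print the resulting split for each case
for k = 1:numel(imagesPerClass)
    fprintf('%4d per class -> %4d training, %3d validation\n', ...
        imagesPerClass(k), numTraining(k), numValidation(k));
end
```

These values could then be passed straight to splitEachLabel, e.g. splitEachLabel(images,numTraining(1),numValidation(1),'randomized').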


Next, I created an imageDataAugmenter with random rotations of +/-10 degrees and random translations of +/-3 pixels in both x and y, via the following code:


imageAugmenter = imageDataAugmenter( ...
    'RandRotation',[-10,10], ...
    'RandXTranslation',[-3 3], ...
    'RandYTranslation',[-3 3]);

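Applying this to the training set, together with the 'gray2rgb' colour pre-processing option, gives the actual training set used (denoted trainingImages_). The 'gray2rgb' option is needed because the dataset mixes file types (jpeg, png, etc.) and, in particular, single-channel greyscale and three-channel colour images; it converts the greyscale images to three channels so that all of them can be imported via the same datastore without error: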


trainingImages_=augmentedImageDatastore(outputSize,trainingImages,'ColorPreprocessing','gray2rgb','DataAugmentation',imageAugmenter);

...and similarly for the validation set but without the augmentation:

validationImages_=augmentedImageDatastore(outputSize,validationImages,'ColorPreprocessing','gray2rgb');

Note that outputSize is set as follows to match the 224x224x3 input size of googlenet, which is the network used for the transfer learning:

outputSize=[224 224 3]; %FOR GOOGLENET
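As I understand it, the 'gray2rgb' option turns a single-channel image into the three channels expected by outputSize by replicating the greyscale channel, and leaves images that already have three channels unchanged. A base-MATLAB sketch of that conversion, using a small hypothetical 4x4 greyscale image (an illustration of the idea, not the toolbox's internal implementation):

```matlab
% Hypothetical 4x4 single-channel (greyscale) image with uint8 pixel values
grayImage = uint8(magic(4));

% Replicate the single channel three times along the third dimension,
% which is the effect of the 'gray2rgb' colour pre-processing option
if size(grayImage, 3) == 1
    rgbImage = repmat(grayImage, [1 1 3]);
else
    rgbImage = grayImage;   % already three-channel: leave unchanged
end

disp(size(rgbImage))
```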

Network Preparation

Here's the code I used to prepare the pre-trained googlenet for transfer learning (i.e., by replacing the final few layers of the network):


net = googlenet;
lgraph = layerGraph(net);

% Replace final layers
lgraph = removeLayers(lgraph, {'loss3-classifier','prob','output'});
numClasses = numel(categories(trainingImages.Labels));

% Learn-rate factors of 10 make the new layer learn faster than the transferred layers
newLayers = [
    fullyConnectedLayer(numClasses,'Name','fc','WeightLearnRateFactor',10,'BiasLearnRateFactor',10)
    softmaxLayer('Name','softmax')
    classificationLayer('Name','classoutput')];

lgraph = addLayers(lgraph,newLayers);

% Connect the last transferred layer remaining in the network
% ('pool5-drop_7x7_s1') to the new layers.
lgraph = connectLayers(lgraph,'pool5-drop_7x7_s1','fc');

The figure below shows the last few layers of the network with the above replacements:

Transfer Learning preparation: replacement of the last few layers of googlenet

which was displayed using the following code:


figure('Units','normalized','Position',[0.3 0.3 0.4 0.4]); 
plot(lgraph) 
ylim([0,10]);


The training options were set as follows (mostly by trial and error!):


miniBatchSize = 10;
MaxEpochs = 12;
% Iterations per epoch, so that validation runs once per epoch
numIterationsPerEpoch = floor(numel(trainingImages.Labels)/miniBatchSize);

options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',MaxEpochs, ...
    'InitialLearnRate',1e-4, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'ValidationData',validationImages_, ...
    'ValidationFrequency',numIterationsPerEpoch, ...
    'ValidationPatience',Inf);   % Inf disables early stopping

...and the actual training is performed via the following line of code:

netTransfer = trainNetwork(trainingImages_,lgraph,options);
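Since 'ValidationFrequency' is specified in iterations, setting it to numIterationsPerEpoch validates once per epoch. A quick base-MATLAB check of the resulting iteration counts for all four examples, assuming (as the later examples say they mirror Example 1) that miniBatchSize = 10 is used throughout, with the training-set sizes and epoch limits quoted in each example:

```matlab
miniBatchSize = 10;

% Training images per class, number of classes and MaxEpochs for Examples 1 to 4
trainPerClass = [640 1717 65 65];
numClasses    = [2   2    2  4];
maxEpochs     = [12  12   7  10];

% Same floor() rule as numIterationsPerEpoch in the training code above
numTrainingImages     = trainPerClass .* numClasses;
numIterationsPerEpoch = floor(numTrainingImages / miniBatchSize);
totalIterations       = numIterationsPerEpoch .* maxEpochs;

for k = 1:numel(trainPerClass)
    fprintf('Example %d: %4d images, %3d iterations/epoch, %4d iterations in total\n', ...
        k, numTrainingImages(k), numIterationsPerEpoch(k), totalIterations(k));
end
```

The much smaller totals for Examples 3 and 4 are why their training-progress plots look so much sparser.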

Results

Since my options include 'Plots','training-progress', the following chart is displayed (updated in real time as training progresses):


Example 1: Training Convergence

The training converges nicely. However, the slight downturn in validation accuracy at the end of the run (black dots on the upper, blue curve) and the slight upturn in validation loss (black dots on the lower, orange curve) indicate a small degree of overfitting, so ideally the training should have been stopped slightly earlier.
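The options above set 'ValidationPatience' to Inf, which disables automatic stopping. Giving it a finite value instead (e.g. 'ValidationPatience',3 in trainingOptions) should stop training once the validation loss has failed to improve on its previous minimum that many times in a row. A base-MATLAB sketch of my reading of that stopping rule, applied to a hypothetical sequence of per-epoch validation losses (illustrative numbers, not values from the run above):

```matlab
% Hypothetical validation loss recorded at each validation check (once per epoch)
validationLoss = [0.62 0.45 0.38 0.33 0.31 0.32 0.31 0.34 0.36 0.37];
patience = 3;

bestLoss  = Inf;
strikes   = 0;
stopCheck = numel(validationLoss);   % default: run to the end

for k = 1:numel(validationLoss)
    if validationLoss(k) < bestLoss
        bestLoss = validationLoss(k);   % new minimum: reset the counter
        strikes  = 0;
    else
        strikes = strikes + 1;          % loss >= best so far: no improvement
        if strikes >= patience
            stopCheck = k;
            break
        end
    end
end

fprintf('Training would stop after validation check %d (best loss %.2f)\n', ...
    stopCheck, bestLoss);
```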

To assess the classification performance, the error statistics for the validation set are conveniently presented as a confusion matrix, computed as follows:


predictedLabelsValidation = classify(netTransfer,validationImages_);
plotconfusion(validationImages.Labels,predictedLabelsValidation);

which produces the following chart:

Example 1: Confusion Matrix for validation set

The performance is quite reasonable, with an average accuracy of 90.2%. Of the images predicted "no", 85.7% were true negatives and 14.3% false negatives; of those predicted "yes", 95.9% were true positives and 4.1% false positives. Caveat: I have not checked whether some spurious feature makes the performance appear artificially better than it is. For example, an identifying text character or similar marking present in some or all of the images might give a definitive clue to the "yes" or "no" nature of the content, such that the classifier is actually (and erroneously) picking up on this clue rather than identifying the actual lung state. I simply took the entire images "as is"; a more rigorous analysis would need to check for this.
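plotconfusion shows two different per-class percentages: the share of each predicted class that is correct (right-hand column) and the share of each true class that is recovered (bottom row), so it helps to be explicit about which is being quoted. A minimal base-MATLAB sketch computing both from a 2x2 matrix of counts, with rows as the true class and columns as the predicted class (the orientation used by confusionmat; plotconfusion displays the transpose). The counts are hypothetical:

```matlab
% Hypothetical confusion matrix of counts: rows = true class, columns = predicted class
% Class order: {'no','yes'}
C = [100  12;
       6 106];

accuracy  = trace(C) / sum(C(:));        % fraction of all images classified correctly
recall    = diag(C)' ./ sum(C, 2)';      % per true class (bottom row of plotconfusion)
precision = diag(C)' ./ sum(C, 1);       % per predicted class (right-hand column)

fprintf('Accuracy: %.1f%%\n', 100 * accuracy);
fprintf('Recall    (no, yes): %.1f%%, %.1f%%\n', 100 * recall);
fprintf('Precision (no, yes): %.1f%%, %.1f%%\n', 100 * precision);
```

To get C from the actual validation results, confusionmat(validationImages.Labels,predictedLabelsValidation) from the Statistics and Machine Learning Toolbox should return counts in this orientation.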


EXAMPLE 2: Classification of "Bacterial" or "Viral" Pneumonia

Data Preparation

For this next task, I assume we know that the patient has pneumonia, and want to train a network to determine whether it is viral or bacterial. Again, this is a two-class problem, with classes "bacteria" and "virus". The training set is constructed by taking the training images from the "yes" bucket of Example 1 (i.e., the known pneumonia cases) and sorting them into "bacteria" and "virus". Again, I did this offline: filenames including the substring "bacteria" or "streptococcus" were placed in the "bacteria" subfolder, and all others in the "virus" subfolder. In this case the two classes had almost equal numbers of images, so the datasets were created as follows (using 2019 images from each class, split 85%/15% train/validate as before):

[trainingImages,validationImages,holdoutImages] = splitEachLabel(images,1717,302,'randomized');

Thereafter, the preparation steps mirror those for Example 1.
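The offline bacteria/virus sort described above can be written in the same style as the Example 1 sorting snippet. A sketch of just the labelling rule, assuming a case-insensitive match (the text only states the substrings "bacteria" and "streptococcus"), applied to a few hypothetical filenames:

```matlab
% Hypothetical filenames from the "yes" (pneumonia) folder
names = {'person1_bacteria_1.jpeg', 'Streptococcus-case-3.png', ...
         'person7_virus_12.jpeg', 'covid-19-case-5.png'};

% Filenames mentioning bacteria or streptococcus go to "bacteria", all others to "virus"
isBacterial = contains(names, {'bacteria','streptococcus'}, 'IgnoreCase', true);

labels = repmat({'virus'}, size(names));
labels(isBacterial) = {'bacteria'};

disp([names; labels]')
```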

Network Preparation

These mirror the steps followed for Example 1.

Results


The corresponding training convergence plot is shown below:

Example 2: Training Convergence

Again, the convergence is good (though not as convincing as in Example 1). There is likewise a (slightly more pronounced) degree of overfitting, which could be eliminated by stopping earlier. The corresponding confusion matrix computed for the validation set is shown below:

Example 2: Confusion Matrix for validation set

The performance is reasonable (though not as good as in Example 1), with an average classification accuracy of 78%. Of the images predicted "bacteria", 79.2% were truly bacterial and 20.8% viral; of those predicted "virus", 76.8% were truly viral and 23.2% bacterial. Caveat: again, I have not checked whether there are underlying clues in the images which exaggerate the performance: I simply took the entire images "as is".

EXAMPLE 3: Classification of COVID-19 or Other-Viral

Data Preparation

For this next task, I assume we know that the patient has some form of viral pneumonia, and want to train a network to determine whether it is COVID-19 rather than some other form (SARS, MERS, etc.). Again, this is a two-class problem, with classes "covid" and "other". The training set is constructed by taking the training images from the "virus" bucket of Example 2 (i.e., the known viral pneumonia cases) and sorting them into "covid" and "other". Again, I did this offline: filenames including the substring "covid" or "corona" were placed in the "covid" subfolder, and all others in the "other" subfolder. In this case there were only 76 COVID-19 images versus 2014 other-viral images, so the datasets were created as follows (using only 76 images from each class, split 85%/15% train/validate as before):

[trainingImages,validationImages,holdoutImages] = splitEachLabel(images,65,11,'randomized');

Thereafter, the preparation steps mirror those in the previous examples.
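Calling splitEachLabel with fixed counts is what balances the classes here: it takes 65 + 11 images at random from each label and leaves the remaining other-viral images in holdoutImages. An explicit equivalent of that random subsampling, sketched in base MATLAB with hypothetical indices standing in for the 2014 other-viral images:

```matlab
rng(0);                       % fix the random seed so the selection is repeatable

numCovid = 76;                % size of the minority (COVID-19) class
otherIdx = 1:2014;            % hypothetical indices of the other-viral images

% Randomly choose as many other-viral images as there are COVID-19 images
perm     = randperm(numel(otherIdx));
selected = otherIdx(perm(1:numCovid));
heldOut  = otherIdx(perm(numCovid+1:end));

% Split the balanced selection 65/11 into training and validation
trainIdx = selected(1:65);
valIdx   = selected(66:end);

fprintf('%d training, %d validation, %d held out\n', ...
    numel(trainIdx), numel(valIdx), numel(heldOut));
```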

Network Preparation

These mirror the steps from the previous examples, except that the maximum number of epochs is set to 7 rather than 12 to reduce overfitting, given the much smaller number of training images.

Results

The corresponding training convergence plot is shown below:

Example 3: Training Convergence

Again, the convergence is good. The plot is far less dense than in the previous examples, owing to the much smaller number of training images and the reduced number of epochs. The corresponding confusion matrix computed for the validation set is shown below.

Example 3: Confusion Matrix for validation set

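The performance is good, albeit on a rather small validation set of only 22 images (11 covid, 11 other-viral), with an average classification accuracy of 95.5% (21 of 22 correct). Every image predicted "covid" was indeed COVID-19 (100% true covid, 0% false covid), while 91.7% of the images predicted "other" were truly other-viral (8.3% false other-virus); the single misclassification, a COVID-19 image labelled as other-viral, amounts to 4.5% of the 22 validation images. Caveat: again, I have not checked whether there are underlying clues in the images which exaggerate the performance: I simply took the entire images "as is".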

EXAMPLE 4: Determine if COVID-19 pneumonia versus Healthy, Bacterial, or non-COVID viral pneumonia

Data Preparation

In this final task, the challenge for the neural network is the most demanding: from a given lung X-ray, determine whether the patient is healthy, or has bacterial pneumonia, non-COVID-19 viral pneumonia, or COVID-19 pneumonia. Unlike all the previous examples, which were (simpler) two-class problems, this is a four-class problem, with classes "healthy", "bacteria", "viral-other", and "covid". The training set is the entire basket of training images, but since there are only 76 COVID-19 images, only 76 images are used from each class (65 for training, 11 for validation) in order to balance the dataset, as follows (i.e., the same code as in Example 3):

[trainingImages,validationImages,holdoutImages] = splitEachLabel(images,65,11,'randomized');

Thereafter, the preparation steps mirror those in the previous examples.
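For the four-class task the earlier filename rules have to be combined, and the order in which they are applied matters. A sketch assuming the precedence "NORMAL" (healthy), then "bacteria"/"streptococcus" (bacteria), then "covid"/"corona" (covid), with everything else labelled viral-other. The precedence, the case-insensitive matching for the last two rules, and the filenames are all my assumptions; the "NORMAL" match is kept case-sensitive as in the Example 1 code:

```matlab
% Hypothetical filenames covering the four classes
names = {'NORMAL2-IM-0012.jpeg', 'person3_bacteria_10.jpeg', ...
         'person9_virus_21.jpeg', 'COVID-19-pneumonia-7.jpg', 'corona-case-2.png'};

labels = repmat({'viral-other'}, size(names));   % default class

% Apply the rules from lowest to highest precedence, so later rules overwrite earlier ones
labels(contains(names, {'covid','corona'}, 'IgnoreCase', true))           = {'covid'};
labels(contains(names, {'bacteria','streptococcus'}, 'IgnoreCase', true)) = {'bacteria'};
labels(contains(names, 'NORMAL'))                                          = {'healthy'};

disp([names; labels]')
```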

Network Preparation

These mirror the steps from the previous examples, except that the maximum number of epochs is set to 10 rather than 12 to reduce overfitting, given the relatively small number of training images.

Results

The corresponding training convergence plot is shown below:

Example 4: Training Convergence

Again, the convergence is good, though not quite as impressive as for some of the two-class examples.

The corresponding Confusion Matrix computed for the validation set is shown below.

Example 4: Confusion Matrix for validation set

The performance is surprisingly good, albeit on a rather small validation set of only 44 images (11 for each class), with an average classification accuracy of 75% (33 of 44 correct). Interestingly, all the COVID-19 images are correctly identified as such, and no non-COVID-19 images are misidentified as COVID-19. Caveat: again, I have not checked whether there are underlying clues in the images which exaggerate the performance: I simply took the entire images "as is".

CONCLUSIONS


  • As an exercise in using MATLAB for Deep Learning, this has been a definite success. I understand that TensorFlow is free of charge and MATLAB is not, but if you do have access to MATLAB with the required toolboxes for Deep Learning, it is a very powerful framework and, in my opinion, easier to use than TensorFlow.
  • As an exercise in distinguishing COVID-19 lung X-ray images from non-COVID-19 images, transfer learning from the pre-trained googlenet seems promising in terms of the observed performance. However, to do this properly, the caveat I have raised throughout (that clues embedded in the images may exaggerate the performance of the classifiers) would need to be properly addressed.
  • That said, as a next step it would be interesting to compare properly the custom network architecture developed in COVID-Net with the (much simpler) approach presented here, i.e., using only a slightly modified googlenet. If anyone has the inclination to program the COVID-Net network in MATLAB, please let me know; I would like to help if I can.
  • I will revisit the COVID-Net data resources and re-train the models whenever the COVID-19 image set becomes more extensive.
  • Keep safe.
