This program uses deep learning to classify pre-processed image tiles (".png") from hematoxylin and eosin (H&E) stained slides of breast cancer tissue into one of the five PAM50 clinical subtypes (Basal, HER2E, Luminal A, Luminal B, Normal-like). Classification is performed by a transfer-learning ResNet-18 model.
To run the program, two inputs are required: a directory path containing "train" and "val" tile images organized into subtype subfolders (-d), and a test directory path containing test tile images organized into the same subtype subfolders (-t). Optionally, you can also specify the number of training/validation epochs (-e, default=10) and/or the batch size (-b, default=50).
Example command (bracketed arguments are optional):
resnet_HE_v5.py -d "../PAM50_balanced/full_train/" -t "../PAM50_balanced/test/" [-e 10] [-b 50]
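The command-line interface described above can be sketched with Python's standard `argparse` module. This is a hypothetical reconstruction, not the program's actual code: the flags (-d, -t, -e, -b) and the defaults (10 epochs, batch size 50) come from this README, while the function name `build_parser` and the help strings are assumptions.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags and defaults described in this README."""
    parser = argparse.ArgumentParser(
        description="Classify H&E tiles into PAM50 subtypes with a ResNet-18 model."
    )
    # Directory holding "train" and "val" tiles, each split into subtype subfolders.
    parser.add_argument("-d", required=True, help="train/val tile directory")
    # Directory holding test tiles organized into the same subtype subfolders.
    parser.add_argument("-t", required=True, help="test tile directory")
    # Optional hyperparameters with the documented defaults.
    parser.add_argument("-e", type=int, default=10, help="number of epochs")
    parser.add_argument("-b", type=int, default=50, help="batch size")
    return parser
```

In a script, `args = build_parser().parse_args()` would read `sys.argv`, after which `args.d`, `args.t`, `args.e`, and `args.b` hold the two directories and the two hyperparameters.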
A single run of this program produces two figures. The first, "Resnetv5_TrainVal_Summary_PAM50_e[number of epochs]_b[batch size number].png", contains line graphs tracking model loss and accuracy over the completed epochs for both the training and validation steps. The second, "Resnetv5_Test_Summary_PAM50_e[number of epochs]_b[batch size number].png", is a bar graph of test set tile classification results: the number of tiles per subtype correctly classified by the model's "First Prediction" is shown in dark blue, and the number correctly classified within the model's "Top 3 Predictions" is shown in light blue. In addition to these two images, the following are printed to stdout: loss and accuracy metrics for each training/validation epoch, including total training time; the full subtype prediction results for two test images per subtype; and overall subtype prediction accuracy metrics for the test set tiles ("First Prediction" and "Top 3 Prediction").
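The figure naming pattern and the per-subtype "First Prediction" and "Top 3 Predictions" counts described above can be sketched in plain Python. This is an illustration, not the program's code: the helper names `summary_filenames` and `tally_test_predictions`, and the assumption that each tile's predictions arrive as a list of subtype names ranked from most to least likely, are made up for this sketch.

```python
from collections import Counter

# The five PAM50 subtypes named in this README.
SUBTYPES = ["Basal", "HER2E", "Luminal A", "Luminal B", "Normal-like"]


def summary_filenames(epochs: int, batch_size: int) -> tuple[str, str]:
    """Build the train/val and test figure names following the documented pattern."""
    suffix = f"PAM50_e{epochs}_b{batch_size}.png"
    return (f"Resnetv5_TrainVal_Summary_{suffix}", f"Resnetv5_Test_Summary_{suffix}")


def tally_test_predictions(true_labels, ranked_predictions, k=3):
    """Count, per subtype, tiles correct on the first prediction and within the top k.

    ranked_predictions[i] is a list of subtype names for tile i, most likely first
    (a hypothetical input format chosen for this sketch).
    """
    first, top_k = Counter(), Counter()
    for truth, ranked in zip(true_labels, ranked_predictions):
        if ranked[0] == truth:
            first[truth] += 1
        if truth in ranked[:k]:
            top_k[truth] += 1
    # Report every subtype, including those with zero correct tiles, for the bar graph.
    return {s: first[s] for s in SUBTYPES}, {s: top_k[s] for s in SUBTYPES}
```

A tile whose true subtype is ranked second counts toward the top-3 bar but not the first-prediction bar, which is why the light blue bars are always at least as tall as the dark blue ones.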
Note: Images referenced below are available in the "Result_Images" folder of this main GitHub repository.
Correctly classifying a breast cancer patient's histological and clinical subtype can greatly improve the accuracy of therapeutic decisions, thereby increasing the likelihood of better survival outcomes. These subtype classifications are routinely made through manual annotation of H&E slides by highly trained pathologists, and a single slide can take hours to days to process with the necessary accuracy. Deep learning techniques are currently being developed to supplement these hand annotations, using computational processing to greatly reduce the pathologist's annotation time while maintaining high accuracy.
Our initial model was based on the PyTorch transfer-learning ResNet-18 classification example (https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html). In this approach, the weights of every layer of the pretrained network are "fine-tuned" on the training tiles, and the model is evaluated on the validation tiles after each epoch. The tutorial code specifies a stochastic gradient descent (SGD) optimizer with a learning rate of 0.001, which was not adjusted for any of our models. This first iteration of the model used two final classes: invasive ductal carcinoma (IDC) and invasive lobular carcinoma (ILC). The model was run at least 10 times while optimizing the epoch number and sample size hyperparameters (with the batch size set to 20). The best version of this initial 2-class model used roughly 3000 training and 1500 validation tiles from 20 whole slide images (WSI) per subtype class, and ran for 10 epochs with a batch size of 20. Runs with these parameters reached maximum training and validation accuracies of 98% (see Resnet_Summary_v2_hist_20BS.png). No test set was run with this model.
After observing high accuracy with 20 WSIs per class in the initial 2-class model, we built a second version of the model with 5 final classes corresponding to the PAM50 clinical subtypes (Basal, HER2E, Luminal A, Luminal B, Normal-like). This 5-class model was used to 1) see whether the loss and accuracy metrics were affected by an increase in the number of final classes, and 2) experiment with batch size while continuing to gradually increase the WSI and tile sample sizes. Runs using batch sizes of 2 (see Resnet_Summary_v2_PAM50_2BS.png), 20 (see Resnet_Summary_v2_PAM50_20BS.png), and 100 (see Resnet_Summary_v2_PAM50_100BS.png) tiles were compared on an unbalanced dataset of roughly 3500-12800 training and 790-3300 validation tiles per subtype folder. The optimal setup for this second-version, 5-class PAM50 model used a batch size of 20 and ran for 10 epochs, reaching maximum training and validation accuracies of over 98%. No test set was run with this model.
After observing high accuracy in the second, 5-class version of the transfer-learning ResNet-18 model, we progressed to a third version. This third version used all ten combined histological and clinical PAM50 breast cancer subtype classes that this project focuses on: IDC Basal, IDC HER2E, IDC Luminal A, IDC Luminal B, IDC Normal-like, ILC Basal, ILC HER2E, ILC Luminal A, ILC Luminal B, and ILC Normal-like. The entire manifest of TCGA images (875 WSIs in total) was preprocessed, and the resulting tiles were used for these (and subsequent) model runs. Balanced versus unbalanced subtype sample sizes, epoch number, and batch size were tested over several iterations of this third-version model. Balancing the number of tile images per subtype class directory did not appear to greatly affect model accuracy, but it is standard practice for machine learning and deep learning models. As such, all subsequent datasets were balanced by randomly downsampling each subtype with more tiles to the tile count of the subtype with the fewest tiles, resulting in a subset of 1500 training and 400 validation tiles for each of the ten subtype classes. Because this number of tiles was not sufficient to produce a training accuracy greater than 65%, each tile was also copied with three transformations: a horizontal flip, a vertical flip, and a combined horizontal and vertical flip. This produced our final dataset size of 6000 training and 1600 validation tiles per subtype class (see Resnet_Val_Summary_v5_10STB_1500TpST_50B_Btest.png and Resnet_Val_Summary_v5_10STB_6000TpST_50B_Btest.png).
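The balancing and flip augmentation steps described above can be sketched in plain Python. This illustrates the technique and is not this project's preprocessing code: the helper names are hypothetical, tiles are stand-ins (any list items for balancing, lists of pixel rows for flipping), and the fixed seed is an assumption added only for reproducibility.

```python
import random


def balance_by_downsampling(tiles_by_class, seed=0):
    """Randomly downsample every class to the tile count of the smallest class."""
    rng = random.Random(seed)
    n_min = min(len(tiles) for tiles in tiles_by_class.values())
    return {label: rng.sample(tiles, n_min) for label, tiles in tiles_by_class.items()}


def flip_augment(tile):
    """Return the original tile plus horizontal, vertical, and combined flips.

    A tile is modeled as a list of pixel rows, standing in for an image array.
    """
    h_flip = [row[::-1] for row in tile]          # mirror left-right
    v_flip = tile[::-1]                           # mirror top-bottom
    hv_flip = [row[::-1] for row in tile[::-1]]   # both flips combined
    return [tile, h_flip, v_flip, hv_flip]
```

Since `flip_augment` returns four versions of each tile, 1500 balanced training tiles per class become 6000, and 400 validation tiles become 1600. On real ".png" tiles, the same three flips would typically be done with an image library's flip or transpose operation.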
With this sample size, extending training to 20 epochs produced minimal changes in accuracy beyond the first 10 epochs (see Resnet_Val_Summary_v4_10STB_6000TpST_50B_20E.png), so 10 epochs were used for all subsequent runs in this project. The optimal batch size was evaluated by comparing accuracy and runtime across three runs using batch sizes of 2 (see Resnet_Val_Summary_v4c_10STB_6000TpST_2BS.png), 50 (see Resnet_Val_Summary_v4_10STB_6000TpST_50BS.png), and 100 (see Resnet_Val_Summary_v4d_10STB_6000TpST_100B.png). The run using a batch size of 50, with the parameters and sampled dataset specified above, reached maximum training and validation accuracies greater than 80% (see Resnet_Val_Summary_v5_10STB_6000TpST_50BS.png). A test set of 600 new tiles per subtype was also run to evaluate the model on "unseen" images. For these tiles, the model's top subtype prediction was correct 9-31% of the time for IDC subtypes and less than 10% of the time for ILC subtypes (with the exception of ILC Luminal A, at 26%). When the model's top 3 predictions per tile are taken into account, these test accuracies increase to 18-64% for IDC subtypes and 2-7% for ILC subtypes (again with the exception of ILC Luminal A, at 74%) (see Resnet_Test_Summary_v5_10STB_6000TpST_50BS.png). Based on these test results, we conclude that the subtypes with more original WSIs (the five IDC subtypes and ILC Luminal A) are more consistently identified correctly by the model than the subtypes with fewer original WSIs (the four remaining ILC subtypes).
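One reason batch size affects runtime is the number of mini-batches (optimizer steps) processed per epoch. The arithmetic below assumes the balanced training set of 6000 tiles for each of the ten classes and a data loader that keeps the final partial batch (PyTorch's `DataLoader` default, `drop_last=False`). Actual runtime also depends on hardware and per-batch cost, so this is only part of the picture.

```python
import math


def steps_per_epoch(num_tiles: int, batch_size: int) -> int:
    """Mini-batches needed to pass over num_tiles once, counting a final partial batch."""
    return math.ceil(num_tiles / batch_size)


# 6000 training tiles per subtype across ten subtype classes.
TRAIN_TILES = 6000 * 10
for bs in (2, 50, 100):
    print(f"batch size {bs:>3}: {steps_per_epoch(TRAIN_TILES, bs)} steps per epoch")
```

This prints 30000, 1200, and 600 steps per epoch for batch sizes 2, 50, and 100. A batch size of 2 therefore requires 25 times as many optimizer updates per epoch as a batch size of 50.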
To finish this project as a valid, usable proof-of-concept program, the final model (currently implemented in this program) reverts to 5 PAM50 classes in the final classification layer because, among the versions with the best training, validation, and test accuracies, this was the one with the most final classes. Argparse functionality was added at this stage to accept user input for the training/validation and testing directories, as well as for the two hyperparameters we experimented with during this project: epoch number and batch size. When run on the balanced, preprocessed PAM50 dataset of all WSIs in the manifest file (42000 training, 11000 validation, and 6600 testing tiles per subtype), the model reaches maximum training and validation accuracies above 90% (see Resnetv5_TrainVal_Summary_PAM50_e10_b50.png). The per-subtype test accuracies of this trained model are slightly better than those of the 10-class version, with top prediction accuracies ranging from 6% to 48% and top 3 prediction accuracies ranging from 43% to 92% (see Resnetv5_Test_Summary_PAM50_e10_b50.png). The Basal and Luminal A test tiles are generally classified with the highest accuracy (49% and 37% top prediction, 84% and 93% top 3 prediction, respectively), followed closely by HER2E (28% top prediction, 79% top 3 prediction). The remaining two PAM50 subtypes, Luminal B and Normal-like, are classified less reliably (7% and 8% top prediction, 57% and 44% top 3 prediction, respectively).
Possible next steps for improving the accuracy of this model could include the following:
- Exploring additional tile transformation methods beyond traditional flipping techniques
- Experimenting with learning rate adjustments and/or different learning optimization methods outside of SGD
- Expanding the WSI dataset for under-represented subtypes - either in the current 5-class model or the 10-class model - to reduce the degree of class balancing currently required
- Creating a version of this ResNet-18-based model that uses a base network with more convolutional layers (e.g., ResNet-34, ResNet-101, ResNet-1202)