In this blog post, we discuss how to train a DenseNet-style deep learning classifier, using PyTorch, for differentiating between different types of lymphoma cancer. This post and code are based on the post discussing segmentation using U-Net and are thus broken down into the same 4 components:
- Making training/testing databases,
- Training a model,
- Visualizing results in the validation set,
- Generating output.
This approach relies solely on Python and freely available tools (i.e., no MATLAB).
This blog post assumes moderate knowledge of convolutional neural networks. Depending on the reader’s background, our JPI paper may be sufficient, or a more thorough resource such as Andrew Ng’s deep learning course may be helpful.
Introduction to DenseNet
DenseNets were introduced here, so please refer there for a complete description; in this post we discuss only the high-level intuition needed to complete the tutorial.
DenseNets consist of multiple dense-blocks, which look like this:
These blocks are the workhorse of the DenseNet. Inside a block, the output of each convolutional layer is concatenated with the outputs of all preceding layers before being fed to subsequent layers. Looking at x_4, one can see that it receives 4 additional inputs (i.e., yellow, purple, green, and red), one from each of the previous convolutional layers. Similarly, x_3 has 3 inputs, x_2 has 2, and x_0 has none, as it is the first convolutional layer.
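To make the concatenation pattern concrete, here is a minimal sketch of a dense block in PyTorch. The sizes and the bare 3×3 convolutions are illustrative only; real DenseNet layers also include batch normalization, ReLU activations, and bottleneck convolutions:

```python
import torch
import torch.nn as nn

class DenseBlock(nn.Module):
    """Toy dense block: each layer receives the concatenation of the
    block input and all previous layers' outputs (illustrative sketch)."""
    def __init__(self, in_channels, growth_rate, n_layers):
        super().__init__()
        # layer i sees in_channels plus one growth_rate's worth of
        # channels from each of the i preceding layers
        self.layers = nn.ModuleList([
            nn.Conv2d(in_channels + i * growth_rate, growth_rate,
                      kernel_size=3, padding=1)
            for i in range(n_layers)
        ])

    def forward(self, x):
        features = [x]
        for layer in self.layers:
            # concatenate all prior feature maps along the channel axis
            out = layer(torch.cat(features, dim=1))
            features.append(out)
        return torch.cat(features, dim=1)

block = DenseBlock(in_channels=8, growth_rate=4, n_layers=4)
y = block(torch.randn(1, 8, 32, 32))
# output channels = 8 input channels + 4 layers x growth rate 4 = 24
```

Note how the channel count grows linearly with depth inside the block, which is why DenseNets insert bottleneck/transition layers between blocks to keep the feature count manageable.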
Multiple sets of these blocks are then sequentially applied, with a bottleneck layer in between them to form the entire network, which looks like this:
The authors of this approach claim “DenseNets exploit the potential of feature reuse, yielding condensed models that are easy to train and highly parameter efficient”.
Unpacking this, we can see the reasoning for these claims:
1) Directly connecting layers throughout the network helps to reduce the vanishing gradient problem
2) Features learned at earlier layers, which likely contain important filters (for example, edge detectors), can be reused directly in later layers rather than being relearned anew. This (a) reduces the overall amount of feature redundancy (each layer doesn’t need to learn its own edge detector), resulting in fewer overall parameters and potentially fewer opportunities for overfitting, and (b) results in faster training times than, e.g., ResNet (DenseNets require no additional computation on “inherited” features, while ResNets require additional operations)
These claims seem justified when looking at the comparison to other network architectures:
Making a database
Regardless of the desired model type used for the classifier, deep learning (DL) is typically best performed using some type of database backend. This database need not be sophisticated (e.g., the LMDB commonly used by Caffe), nor must it be a true database (e.g., here we’ll discuss using PyTables, which has an HDF5 backend). The benefits of using a database are numerous, but we can briefly list some of them here:
- As individual files, extracted image “patches” can often number into the thousands or millions in the case of DL. The access time can be very slow as modern operating systems are not typically tuned for efficient access of large numbers of files. Trying to do something as simple as “ls” in a directory in Windows/Linux with millions of files can cause a notable lag.
- Improved speed, as a result of sequential reads off of disk, OS-level caching, and improved compression
- Improved reproducibility, coming back to a project later which has a database allows for much more consistent retraining of the model, especially when other files may have been lost (e.g., lists of training and testing sets)
- Better protection against data leakage. Although not required, I prefer creating separate databases (in this case files), for training, validation, and testing purposes. This can greatly reduce human error in making sure that only training images are read during training time, and that validation/test images are completely held out.
I’ve previously written a blog post about how to get images into an HDF5 table using MATLAB, but this post uses only Python and is more refined. We also focus on the types of data necessary for a DenseNet, storing both (1) the original image and (2) its associated label in the database:
The code to do this is available here.
Note that this code chops images into overlapping tiles, at a user specified stride, which is very fast based on python views. If your experiment focuses on only smaller annotated pieces of the image, this code would need to be adjusted (e.g., ROIs of localized disease presentation).
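As an illustration of view-based tiling (not the post’s exact code), NumPy’s `sliding_window_view` can chop an image into overlapping tiles at a given stride without copying data until the final reshape:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def tile_image(img, tile_size, stride):
    """Chop an HxWxC image into overlapping tiles using views.
    tile_size and stride are illustrative parameters; the final
    reshape materializes the tiles into a contiguous array."""
    windows = sliding_window_view(img, (tile_size, tile_size, img.shape[2]))
    # apply the stride and drop the singleton channel-window axis
    windows = windows[::stride, ::stride, 0]
    return windows.reshape(-1, tile_size, tile_size, img.shape[2])

img = np.zeros((100, 100, 3), dtype=np.uint8)
tiles = tile_image(img, tile_size=32, stride=16)
# 5 x 5 grid of 32x32 tiles -> 25 tiles
```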
That said, some specific items worth pointing out in the code:
- As per cell 3, images are stored as unsigned 8-bit integers, and thus their values range from [0,255]
- Cell 4 assumes that each image represents a unique patient and thus naively splits the images into training and validation batches. If this is not the case, it should be addressed there by assigning appropriate files to each of the phase dictionary items. Always remember, training and validation should take place at a patient level (i.e., a patient is either in the training set or the testing set, but never both)
- We use a modest compression level of 6, which some experiments have shown to be a nice trade-off between time and size. This is easily modifiable in the code by changing the value of “complevel”
- Images are stored in the database in [IMAGE, NROW,NCOL,NCHANNEL] format. This allows for easier retrieval later. Likewise, the chunk size is set to the tile size such that it is compressed independently of the other items in the database, allowing for rapid data access.
- The class is determined by looking at the filename for one of the 3 specified labels. In this case, each of the classes is in its own unique directory, with the correct associated class name.
- The labels, filenames, and total number of instances of each class are stored in the database for downstream reference (the latter to allow for class weighting).
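A sketch of what such a PyTables setup might look like is shown below. The file name and array names here are illustrative stand-ins, not necessarily those used in the actual code, but the compression level, chunking strategy, and storage layout follow the points above:

```python
import numpy as np
import tables

tile_size, n_channels = 32, 3

# zlib at complevel 6: the trade-off between time and size noted above
filters = tables.Filters(complevel=6, complib="zlib")

with tables.open_file("train.pytable", mode="w") as hdf5:
    # images stored as uint8 in [IMAGE, NROW, NCOL, NCHANNEL] layout;
    # chunkshape of one tile means each tile compresses independently,
    # allowing rapid random access
    imgs = hdf5.create_earray(hdf5.root, "imgs", tables.UInt8Atom(),
                              shape=(0, tile_size, tile_size, n_channels),
                              chunkshape=(1, tile_size, tile_size, n_channels),
                              filters=filters)
    labels = hdf5.create_earray(hdf5.root, "labels", tables.UInt8Atom(),
                                shape=(0,), filters=filters)

    # append one dummy tile with class label 0 for illustration
    tile = np.zeros((1, tile_size, tile_size, n_channels), dtype=np.uint8)
    imgs.append(tile)
    labels.append(np.array([0], dtype=np.uint8))
```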
Training a model
Now that we have the data ready, we’ll train the network. The DenseNet architecture is provided by PyTorch in the torchvision package, in a very modular fashion. Thus, the main components that we need to develop and discuss here are how to get our data into and out of the network.
One important practice which is commonly overlooked is to visually examine a sample of the input which will be going to the network, which we do in this cell:
We can see that the augmentations have been done properly (the image looks “sane”). We can also note how the color augmentation has drastically changed the appearance of the image (left) from H&E to a greenish color space, in hopes of greater generalizability later. Here I’m using the default values as an example, but tuning these values will likely improve results if they’re tailored to the specific test set of interest.
- In this example, we’ve used a reduced version of DenseNet for prototyping (as defined by the parameters in the first cell). For production usage, these values should be tuned for optimal performance. In the 5th cell, one can see the default parameters. For an explanation of the parameters, please refer to both the code and manuscript links for DenseNet provided above.
- In this configuration, setting num_workers > 0 will improve the speed of the network. This requires PyTorch version > 0.4.1, preferably v1. The improvement in speed is a result of performing the augmentation (which is quite heavy) in parallel on the CPUs while the GPU is operating on the current batch. Your mileage may vary.
- This code is heavily reused for both training and validation through the notion of “phases”; in particular, the cell labeled 134 contains a bit of intertwined code which sets appropriate backend functionality (e.g., enabling/disabling gradient computation)
- Note that the validation step is very expensive in time and should be enabled with care.
- We use tensorboardX to save real-time statistics from Python for viewing in TensorBoard, helping to visualize training progress. Since classification can have more than 2 classes, we save each element of the confusion matrix individually to the TensorBoard log. For any i,j where i==j, the confusion matrix count is a true positive; that is to say, in our example, val/00, val/11, and val/22 indicate correct classifications, while val/01 indicates a sample of class 0 being incorrectly classified as class 1.
- We save the best snapshot, overwriting the previous one. Along with it, we save all information necessary to both load the network and continue training later on
- The bottom of the notebook shows how to visualize both individual kernels and activations. Note that, to be efficient, PyTorch does not keep activations in memory after the network has finished computing; thus, it is impossible to retrieve them after the model makes its prediction. As a result, we create a hook which saves the activations we’re interested in when the layer of interest is encountered.
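The hook mechanism can be sketched as follows. The toy model and layer name here are illustrative; the idea is simply that a forward hook stashes a copy of the activations at the moment they are computed:

```python
import torch
import torch.nn as nn

# dict to collect activations captured during the forward pass
activations = {}

def save_activation(name):
    def hook(module, inputs, output):
        # detach so the stored tensor does not hold onto the graph
        activations[name] = output.detach()
    return hook

# toy model standing in for the DenseNet; layer name is hypothetical
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                      nn.Conv2d(8, 4, 3, padding=1))
model[0].register_forward_hook(save_activation("conv1"))

_ = model(torch.randn(1, 3, 16, 16))
# activations["conv1"] now holds the first conv layer's output
```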
Visualizing results in the validation set
Since we’re consistently saving the best model as the classifier is training, we can easily interrogate the results on the validation set while the network itself is still training. This is best done with 2 GPUs available, so that the main GPU can continue training while the second GPU generates the output. If the network is small enough and the memory of a single GPU large enough, both processes can be done on the same GPU.
A minimal example of how to do this is provided here. Note that there is a very large code overlap with the training code, pruned down solely to generate output.
- Augmentation in the cell labeled 18: when using this for output generation, we want to use the original images, since they will give a better sense of the expected output when the model is applied to the rest of the dataset; as a result, we disable all unnecessary augmentation. The only component that remains is the RandomCrop, which ensures that, regardless of the size of the image in the database, we extract an appropriately sized patch
- We can see the output has 2 components. The numerical components show the ground truth class versus the predicted class, as well as the raw deep learning output (i.e., pre-argmax). Additionally, we can see the input image after and before augmentation.
- At the bottom of the output, we can see the confusion matrix for this particular subset, with the accuracy (here shown to be 100% on 4 examples).
Lastly, we’d likely want to apply the model to generate output. For that, I’ve written a prototype command line python script available here.
Given the other files, the code here should be rather self-explanatory when it comes to generating the output.
The output is generated in a naïve fashion: a large image is chopped into overlapping tiles, and those tiles are fed as a batch to the model.
The output of interest can be found near the bottom of the code, left for the user to decide what and how to save:
- output contains the ntile x nclass DL output
- tileclass is a ntile x 1 vector indicating the predicted class of the tile
- predc and preddcounts contain the predicted class and its associated counts, respectively
- Lastly, the script outputs the predicted class as determined by the class which obtained the maximal tile votes (i.e., majority)
- This script accepts wildcards, filenames, or a TSV file containing a subset of files to be computed (similar to HistoQC). This can make producing results on just a test set easier.
- The resize factor should result in the same magnification/microns-per-pixel/resolution as the resize factor used in make_hdf5
- The patch size should match that used in train_densenet
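The voting logic described in the bullets above can be sketched in NumPy; the scores here are toy data, and the variable names loosely mirror those in the script:

```python
import numpy as np

# toy DL output for 5 tiles over 3 classes (rows ~ per-class scores)
output = np.array([[0.70, 0.20, 0.10],
                   [0.60, 0.30, 0.10],
                   [0.10, 0.80, 0.10],
                   [0.90, 0.05, 0.05],
                   [0.20, 0.20, 0.60]])

tileclass = np.argmax(output, axis=1)       # predicted class per tile
predc, predcounts = np.unique(tileclass, return_counts=True)
slide_class = predc[np.argmax(predcounts)]  # majority vote across tiles
# tiles vote [0, 0, 1, 0, 2] -> majority class is 0
```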
I’m hoping that you find this approach useful and generalizable to your own use cases.
This project is of course a work in progress and will be updated, so please keep track of the revision numbers at the top of the files to ensure that you’re using the latest and greatest version!
Any questions or comments are of course welcome!
Link to GitHub repository here.