In this blog post, we discuss how to train a U-net style deep learning classifier, using PyTorch, for segmenting epithelium versus stroma regions. This post is broken down into four components, following the other pipeline approaches we’ve discussed in the past:
- Making training/testing databases,
- Training a model,
- Visualizing results in the validation set,
- Generating output.
This approach uses only Python and freely available tools (i.e., no MATLAB).
This blog post assumes moderate knowledge of convolutional neural networks. Depending on the reader’s background, our JPI paper may be sufficient, or a more thorough resource such as Andrew Ng’s deep learning course.
Introduction to U-Net
U-Nets were introduced here; please refer to that paper for a complete description. Here we discuss only the high-level intuition needed to complete this tutorial.
A U-net essentially involves 2 parts: a compression/encoding component and a decompression/decoding component. The network takes the original RGB image as input, performs semantic segmentation, and produces a binary mask (i.e., assigns each spatial location in the original image to a particular class).
The main benefit of this network architecture is the connection between coarse and fine spatial information. Copy and crop connections feed fine detail to the decoding layers, improving localization, while the compression/encoding path allows increasing levels of abstraction for coarse localization. This abstraction is typically achieved with downsampling operations, such as max pooling (as in the original U-Net) or convolutions with strides greater than 1. For example, the first layer may consist of a 256×256×3 input tile, the subsequent layers being 128×128×nkernels, 64×64×nkernels, 32×32×nkernels, etc., where nkernels is the number of unique kernels learned per layer.
During the decoding/decompression phase, an up-convolution operation (i.e., ConvTranspose2d) essentially undoes this downsampling via learnable parameters, such that the 32×32 layer becomes 64×64, then 128×128, etc., until an output image of the original image resolution (i.e., size in pixels) is obtained, except hopefully ( ::fingers crossed:: ) this time containing a segmentation mask. Alternatively, instead of a learned up-convolution layer, an interpolation layer with no learnable parameters can be used, but this may produce inferior results.
The network then compares the model’s output mask with the provided ground truth. Each pixel contributes to the network loss, and as such to the error derivative of each learnable parameter. Since the decoding layers receive higher-resolution kernel activations through the copy and crop connections, the network can more easily propagate context/spatial information. This allows us to train models with better spatial awareness than naïve pixel-level classifiers applied using a sliding window.
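To make the shape changes concrete, here is a minimal PyTorch sketch of the two downsampling options and the two upsampling options mentioned above. The 16-kernel width is an arbitrary value chosen for illustration, not a setting taken from the notebook.

```python
import torch
import torch.nn as nn

nkernels = 16                      # arbitrary width, for illustration only
x = torch.randn(1, 3, 256, 256)    # one 256x256 RGB tile in PyTorch's (N, C, H, W) layout

# A 3x3 convolution with padding=1 keeps the 256x256 spatial size
features = nn.Conv2d(3, nkernels, kernel_size=3, padding=1)(x)

# Encoding: two common ways to halve the spatial size
pooled = nn.MaxPool2d(kernel_size=2)(features)                                         # no parameters
strided = nn.Conv2d(nkernels, nkernels, kernel_size=3, stride=2, padding=1)(features)  # learnable

# Decoding: two ways to double it again
up_learned = nn.ConvTranspose2d(nkernels, nkernels, kernel_size=2, stride=2)
up_fixed = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
restored_learned = up_learned(pooled)
restored_fixed = up_fixed(pooled)

def n_params(module):
    return sum(p.numel() for p in module.parameters())

print(tuple(features.shape))          # (1, 16, 256, 256)
print(tuple(pooled.shape))            # (1, 16, 128, 128)
print(tuple(strided.shape))           # (1, 16, 128, 128)
print(tuple(restored_learned.shape))  # (1, 16, 256, 256)
print(tuple(restored_fixed.shape))    # (1, 16, 256, 256)
print(n_params(up_learned), n_params(up_fixed))  # 1040 0
```

The transposed convolution carries 16·16·2·2 weights plus 16 biases that are trained with the rest of the network, whereas the bilinear interpolation layer has nothing to learn.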
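The copy and crop connection itself amounts to a center crop followed by a channel-wise concatenation. The sketch below uses hypothetical feature-map sizes of the kind produced by unpadded 3×3 convolutions (as in the original U-Net), where the encoder map is larger than the upsampled decoder map; with padded convolutions the two sizes already match and the crop does nothing. `center_crop` is a helper written for this example.

```python
import torch

def center_crop(features, target_hw):
    """Crop an (N, C, H, W) feature map to the target (H, W), keeping the center."""
    _, _, h, w = features.shape
    th, tw = target_hw
    top, left = (h - th) // 2, (w - tw) // 2
    return features[:, :, top:top + th, left:left + tw]

enc = torch.randn(1, 256, 136, 136)  # fine-detail features saved from the encoding path
dec = torch.randn(1, 256, 104, 104)  # coarse features after the up-convolution

# "Copy and crop": crop the encoder features, then stack them with the decoder features
merged = torch.cat([center_crop(enc, dec.shape[2:]), dec], dim=1)
print(tuple(merged.shape))  # (1, 512, 104, 104)
```

The following convolutions then see both the fine (cropped encoder) and coarse (decoder) channels at every pixel.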
Making a database
Regardless of the model type used for the classifier, deep learning is typically best performed using some type of database backend. This database need not be sophisticated (e.g., the LMDB commonly used by Caffe), nor must it be a true database (e.g., here we use PyTables, which has an HDF5 backend). The benefits of using a database are numerous; we briefly list some of them here:
- In DL, extracted image “patches” can number in the thousands or millions. Stored as individual files, they can be very slow to access, as modern operating systems are not typically tuned for efficient access of large numbers of files. Trying to do something as simple as “ls” in a Windows/Linux directory with millions of files can cause a notable lag.
- Improved speed as a result of more efficient reading from disk, OS-level caching, and improved compression
- Improved reproducibility: coming back later to a project that has a database allows for much more consistent retraining of the model, especially when other files may have been lost (e.g., lists of training and testing sets)
- Better protection against data leakage. Although not required, I prefer creating separate databases (in this case, files) for training, validation, and testing purposes. This can greatly reduce human error, ensuring that only training images are read during training and that validation/test images are completely held out.
I’ve previously written a blog post about how to get images into an HDF5 table using MATLAB, but this blog post uses only Python and is more refined. We also focus on the types of data necessary for a U-net, which comes down to storing both the original image and its associated mask in the database.
The code to do this is available here.
One small caveat: this code chops images into non-overlapping tiles, which is very fast because it relies on array views. Producing overlapping tiles would require modifying the code, which is planned for future versions.
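As an illustration of the view-based idea (not the make_hdf5 code itself), NumPy’s `sliding_window_view` (NumPy 1.20 or later) exposes every possible tile position as a view, and slicing that view with a step gives non-overlapping tiles. Passing a step smaller than the tile size gives the overlapping variant mentioned above; `tile_image` and its `stride` parameter are hypothetical names for this sketch.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def tile_image(img, tile, stride=None):
    """Split an (H, W, C) image into an (n_tiles, tile, tile, C) array.

    stride == tile (the default) gives non-overlapping tiles; stride < tile overlaps them.
    Border pixels that cannot fill a whole tile are dropped.
    """
    stride = tile if stride is None else stride
    channels = img.shape[2]
    # Read-only view of every tile position: (H - tile + 1, W - tile + 1, 1, tile, tile, C)
    windows = sliding_window_view(img, (tile, tile, channels))
    # Keep every `stride`-th position; this is still a view, no pixels are copied
    windows = windows[::stride, ::stride, 0]
    # Flatten the grid into a batch in row-major order (this step copies the data)
    return windows.reshape(-1, tile, tile, channels)

img = np.random.default_rng(0).integers(0, 256, size=(600, 520, 3), dtype=np.uint8)
tiles = tile_image(img, 256)
print(tiles.shape)                              # (4, 256, 256, 3)
print(tile_image(img, 256, stride=128).shape)   # (9, 256, 256, 3)
```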
That said, some specific items worth pointing out in the code:
- As per cell 2, images and masks are stored as unsigned 8-bit integers, and thus their values range from [0,255]
- Cell 3 assumes that each image represents a unique patient and thus naively splits the images into training and validation sets. If this is not the case, it should be addressed there by assigning the appropriate files to each of the phase dictionary items. Always remember, training and validation splits should be made at the patient level (i.e., a patient is either in the training set or the testing set, but never both)
- We use a modest compression level of 6, which some experiments have shown to be a nice trade-off between time and size. This is easily modified in the code by changing the value of “complevel”
- Images are stored in the database in [IMAGE, NROW, NCOL, NCHANNEL] format, which allows for easier retrieval later. Likewise, the chunk size is set to the tile size so that each tile is compressed independently of the other items in the database, allowing for rapid data access.
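The points above (patient-level splitting, uint8 storage, complevel=6, and one tile per chunk) can be sketched with PyTables as follows. This is a minimal sketch rather than the notebook’s code: `split_by_patient`, `write_db`, the `img`/`mask` array names, the filename-based patient ID, and storing masks without a channel axis are all assumptions made for illustration.

```python
import os
import tempfile
import numpy as np
import tables

def split_by_patient(files, patient_of, val_fraction=0.2, seed=0):
    """Assign whole patients, never individual images, to 'train' or 'val'."""
    patients = sorted({patient_of(f) for f in files})
    order = np.random.default_rng(seed).permutation(len(patients))
    n_val = max(1, round(len(patients) * val_fraction))
    val_patients = {patients[i] for i in order[:n_val]}
    return {"train": [f for f in files if patient_of(f) not in val_patients],
            "val": [f for f in files if patient_of(f) in val_patients]}

def write_db(path, img_tiles, mask_tiles, complevel=6):
    """Store uint8 tiles as [IMAGE, NROW, NCOL, NCHANNEL], one tile per compressed chunk."""
    _, nrow, ncol, nchan = img_tiles.shape
    filters = tables.Filters(complevel=complevel)   # zlib compression by default
    with tables.open_file(path, mode="w") as h5:
        img_arr = h5.create_earray(h5.root, "img", tables.UInt8Atom(),
                                   shape=(0, nrow, ncol, nchan),
                                   chunkshape=(1, nrow, ncol, nchan), filters=filters)
        mask_arr = h5.create_earray(h5.root, "mask", tables.UInt8Atom(),
                                    shape=(0, nrow, ncol),
                                    chunkshape=(1, nrow, ncol), filters=filters)
        img_arr.append(img_tiles.astype(np.uint8))
        mask_arr.append(mask_tiles.astype(np.uint8))

# Hypothetical file names whose prefix identifies the patient
files = ["p01_a.png", "p01_b.png", "p02_a.png", "p03_a.png", "p04_a.png"]
phases = split_by_patient(files, patient_of=lambda f: f.split("_")[0])

rng = np.random.default_rng(0)
imgs = rng.integers(0, 256, size=(3, 256, 256, 3), dtype=np.uint8)
masks = rng.integers(0, 2, size=(3, 256, 256), dtype=np.uint8)
path = os.path.join(tempfile.mkdtemp(), "train.pytable")
write_db(path, imgs, masks)

with tables.open_file(path, mode="r") as h5:
    stored_shape = h5.root.img.shape
    stored_chunks = h5.root.img.chunkshape
    first_mask = h5.root.mask[0]
print(stored_shape, stored_chunks)   # (3, 256, 256, 3) (1, 256, 256, 3)
```

Because each chunk holds exactly one tile, reading `h5.root.img[i]` only has to decompress that one tile.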
Training a model
Now that we have the data ready, we’ll train the network. One of the joys of working in the DL field currently is the number of people who are willing to open source their code; in this case, the U-Net network itself was made available on GitHub at https://github.com/jvanvugt/pytorch-unet, in a very modular and easily extendable fashion. Thus, the main components we need to develop and discuss here are how to get our data in and out of the network.
One important but commonly overlooked practice is to visually examine a sample of the input that will be fed to the network, which we do in this cell.
We can see that the augmentations have been done properly (the mask and image match), and the edges have been sufficiently identified (discussed below). We can also note how the color augmentation has drastically changed the appearance of the image (left) from H&E to a greenish color space, in hopes of greater generalizability later. Here I’m using the default values as an example, but tuning these values toward the specific test set of interest will likely improve results.
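The seed trick that keeps the image and mask aligned is easiest to see in a stripped-down form. In the NumPy sketch below, flips and 90° rotations stand in for the notebook’s affine transforms and a per-channel shift stands in for its color augmentation; these substitutions and the helper names (`affine_aug`, `color_aug`, `augment_pair`) are hypothetical. The image and the mask each get a fresh generator with the same seed, and because the affine draws come first in both pipelines, the mask replays exactly the same geometry.

```python
import numpy as np

def affine_aug(arr, rng):
    """Random horizontal flip and 90-degree rotation; always draws 2 numbers from rng."""
    if rng.random() < 0.5:
        arr = arr[:, ::-1]
    k = int(rng.integers(0, 4))
    return np.rot90(arr, k, axes=(0, 1))

def color_aug(img, rng, max_shift=30):
    """Random per-channel intensity shift, applied to the RGB image only."""
    shift = rng.integers(-max_shift, max_shift + 1, size=3)
    return np.clip(img.astype(np.int16) + shift, 0, 255).astype(np.uint8)

def augment_pair(img, mask, seed):
    rng = np.random.default_rng(seed)
    img_out = color_aug(affine_aug(img, rng), rng)  # affine draws first, then color draws
    rng = np.random.default_rng(seed)               # "reset the seed"
    mask_out = affine_aug(mask, rng)                # replays the identical affine draws
    return img_out, mask_out

img = np.random.default_rng(1).integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
mask = img[..., 0] > 127   # toy mask derived from the image, so alignment can be checked
img_aug, mask_aug = augment_pair(img, mask, seed=42)
```

If the color step ran before the affine step, the image pipeline would consume extra random numbers first and the mask would receive a different geometric transform, which is why the ordering matters. An edge map would be handled exactly like the mask.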
- In the above image of the U-net, there are blocks consisting of multiple convolution layers, all of which have the same number of rows and columns. The depth specifies the number of these blocks going down; between successive blocks, the spatial size of the input is halved (via 2×2 max pooling in the referenced implementation).
- wf specifies the number of kernels learned at each layer: it is an exponent applied per block, giving 2**(wf+layer_number) kernels. Thus deeper levels have a higher number of kernels.
- The depth=5 and wf=6 parameters correspond to the values in the original paper (64 kernels in the first block), but for most applications I’ve used them for, this is significant overkill (the number of parameters is 31,043,586). Consider making both of them (especially wf) smaller if there are overfitting, speed, or memory issues. With the patch size of 256 used here, a depth of 5 can be supported (256 / 2^5 > 1); smaller patches would require a smaller depth.
- Edge_weight allows for synthetically adding an additional penalty for making errors on edges. By doing so, we help the network understand that although edge pixels are infrequent (compared to the bulk of the object), they are very important and need special attention (see Figure 3 here). The edges in this code are determined automatically via fast morphological operations. If more sophisticated or computationally demanding approaches are used, it is not unreasonable to perform this computation once and save it in the DB alongside the mask. In this particular dataset, the annotations are very coarse, so we don’t use a strong edge weight, but when the annotations are more finely produced, higher values are justified, even as high as 10, implying that edge errors contribute an order of magnitude more to the loss value.
- Important to note for augmentation: we set the random seed and perform all affine augmentations before color augmentations. This ensures that, by resetting the seed, the mask and edge images follow exactly the same geometric transformations as the image.
- In this configuration, setting num_workers > 0 actually slows down training. It seems the overhead isn’t justified when simply loading from a PyTable (which implies our DB backend is pretty fast!). Your mileage may vary.
- This code is heavily reused for both training and validation through the notion of “phases”; in particular, cell 12 contains a bit of intertwined code which sets the appropriate backend functionality (e.g., enabling/disabling gradient computation)
- We use tensorboardX to save real-time statistics from Python for viewing in TensorBoard, helping to visualize the training progress. We also break statistics down into true positives, false positives, true negatives, and false negatives, to see if a certain class is being favored (which can then be adjusted using the weights in cell 10). Here we can see that after about 30 epochs we’ve already reached an accuracy of 90%.
- We save the best snapshot, overwriting the previous one. Along with it, we save all the information necessary both to load the network and to continue training later on
- The bottom of the notebook shows how to visualize both individual kernels and activations. Note that, to be efficient, PyTorch does not keep activations in memory after the network is done computing, so they cannot be retrieved after the model makes its prediction. As a result, we register a hook which saves the activations we’re interested in when the layer of interest is encountered.
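The 31,043,586 figure can be reproduced by hand, which also shows how strongly depth and wf drive the model size. The pure-Python sketch below counts parameters for an architecture assumed to match the referenced pytorch-unet with padding and batch normalization enabled (two 3×3 convolutions per block, one 2×2 transposed convolution per up step, a final 1×1 convolution), 3 input channels and 2 classes; `unet_param_count` is our own helper, not part of that repository.

```python
def unet_param_count(in_channels=3, n_classes=2, depth=5, wf=6, batch_norm=True):
    """Hand-derived parameter count for a U-Net laid out as described above."""
    def conv(cin, cout, k):
        return cin * cout * k * k + cout           # weights + biases

    def block(cin, cout):
        n = conv(cin, cout, 3) + conv(cout, cout, 3)
        if batch_norm:
            n += 2 * (2 * cout)                    # gamma and beta for each of the 2 BN layers
        return n

    widths = [2 ** (wf + i) for i in range(depth)]  # kernels per block: 2**(wf + layer_number)
    total, prev = 0, in_channels
    for w in widths:                                # contracting path
        total += block(prev, w)
        prev = w
    for w in reversed(widths[:-1]):                 # expanding path
        total += conv(prev, w, 2)                   # ConvTranspose2d(prev, w, 2, stride=2)
        total += block(2 * w, w)                    # input = upsampled + copied features
        prev = w
    return total + conv(prev, n_classes, 1)         # final 1x1 classifier

print(unet_param_count())          # 31043586
print(unet_param_count(wf=4))      # 1944066
```

Because most weights scale with the product of a layer’s input and output widths, lowering wf by 2 (4× fewer kernels per block) cuts the parameter count by roughly 16×.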
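The edge weighting described above can be sketched as a per-pixel cross-entropy multiplied by a weight map. In this sketch, edges are found with a cheap 4-neighbour label comparison (one simple morphological choice, not necessarily the operation the notebook uses), and each edge pixel’s loss is scaled by `edge_weight`. The helpers `edge_map` and `edge_weighted_loss` are hypothetical names.

```python
import numpy as np
import torch
import torch.nn.functional as F

def edge_map(mask):
    """Boolean map of boundary pixels: a pixel is an edge if any 4-neighbour has another label."""
    padded = np.pad(mask, 1, mode="edge")
    center = padded[1:-1, 1:-1]
    edge = np.zeros(mask.shape, dtype=bool)
    for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        neighbour = padded[1 + dy:padded.shape[0] - 1 + dy, 1 + dx:padded.shape[1] - 1 + dx]
        edge |= neighbour != center
    return edge

def edge_weighted_loss(logits, target, edge, edge_weight):
    """Mean per-pixel cross-entropy where edge pixels count `edge_weight` times as much."""
    per_pixel = F.cross_entropy(logits, target, reduction="none")   # shape (N, H, W)
    weights = 1.0 + (edge_weight - 1.0) * edge.float()
    return (per_pixel * weights).mean()

mask = np.zeros((6, 6), dtype=np.int64)
mask[2:4, 2:4] = 1                       # a 2x2 object in a 6x6 tile
edge = edge_map(mask)                    # 12 boundary pixels (both sides of the border)

target = torch.from_numpy(mask)[None]    # (1, H, W) int64 class labels
edge_t = torch.from_numpy(edge)[None]    # (1, H, W) bool
logits = torch.zeros(1, 2, 6, 6)         # uninformative prediction for every pixel
loss = edge_weighted_loss(logits, target, edge_t, edge_weight=10.0)
```

With all-zero logits every pixel’s cross-entropy is ln 2, so an edge weight of 10 on 12 of the 36 pixels raises the mean loss to 4 ln 2: an error on an edge pixel counts ten times as much as one in the interior.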
Visualizing results in the validation set
Since we’re consistently saving the best model as the classifier trains, we can easily visualize the results on the validation set while the network itself is still training. This is best done if 2 GPUs are available, so that the main GPU can continue training while the second GPU generates the output. If the network is small enough and a single GPU’s memory large enough, both processes can run on the same GPU.
A minimal example of how to do this is provided here. Note that there is a very large code overlap with the training code, pruned down solely to generate output.
- Augmentation in cell 6: when generating output, we want to use the original images, since they give a better sense of the expected output on the rest of the dataset; as a result, we disable all unnecessary augmentation. The only component that remains here is the RandomCrop, to ensure that regardless of the size of the image in the database, we extract an appropriately sized patch
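A minimal sketch of the save-then-reload pattern, assuming a tiny stand-in model in place of the U-Net and hypothetical dictionary keys (`model_dict`, `optim_dict`, and so on, which are not necessarily the notebook’s names). The training process overwrites a single checkpoint file whenever validation improves; the visualization process loads it with `map_location`, so it lands on a second GPU when one is present and on the CPU otherwise.

```python
import os
import tempfile
import torch
import torch.nn as nn

def make_model():
    # Tiny stand-in for the U-Net, for illustration only
    return nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 2, 1))

# --- training process: overwrite the checkpoint when validation loss improves ---
model = make_model()
optimizer = torch.optim.Adam(model.parameters())
epoch, best_val_loss = 0, float("inf")   # placeholder bookkeeping values
path = os.path.join(tempfile.mkdtemp(), "best_model.pth")
torch.save({"epoch": epoch,
            "model_dict": model.state_dict(),
            "optim_dict": optimizer.state_dict(),   # needed to resume training later
            "best_val_loss": best_val_loss}, path)

# --- visualization process: load the latest best snapshot on a spare device ---
device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cpu")
checkpoint = torch.load(path, map_location=device)
viewer = make_model().to(device)
viewer.load_state_dict(checkpoint["model_dict"])
viewer.eval()   # switch dropout/batch-norm layers to inference behaviour

with torch.no_grad():
    out = viewer(torch.randn(1, 3, 64, 64, device=device))
```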
- We can see the output has 4 columns as follows: (a) DL raw output, (b) DL output after argmax, (c) ground truth, (d) original image
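Columns (a) and (b) differ only by an argmax over the class dimension. The NumPy sketch below converts a hypothetical raw output of shape (N, n_classes, H, W) into per-class probabilities and then hard labels, and tallies the same true/false positive/negative counts that are logged during training; the random logits and the toy ground truth are placeholders, not real model output.

```python
import numpy as np

def softmax(logits, axis=1):
    z = logits - logits.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def confusion_counts(pred, gt, positive=1):
    """Pixel-level TP/FP/TN/FN counts for a binary segmentation."""
    p, g = pred == positive, gt == positive
    return {"TP": int(np.sum(p & g)), "FP": int(np.sum(p & ~g)),
            "TN": int(np.sum(~p & ~g)), "FN": int(np.sum(~p & g))}

logits = np.random.default_rng(0).normal(size=(1, 2, 4, 4))  # placeholder raw output
probs = softmax(logits)            # like column (a): per-class probabilities
pred = probs.argmax(axis=1)        # like column (b): one class label per pixel, (N, H, W)
gt = np.zeros((1, 4, 4), dtype=int)
gt[:, :2] = 1                      # toy ground truth: top half is class 1
counts = confusion_counts(pred, gt)
print(counts)
```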
Generating output
Lastly, of course, we’d like to apply the model to new images to generate output. For that, I’ve written a prototype command line Python script, available here.
Given the other files, the code here should be rather self-explanatory when it comes to generating the output.
The output is generated in a naïve fashion, where a large image is chopped into tiles, those tiles fed as a batch to the model, and the results merged back together. This usually does well for most of the tile but tends to do poorly around the edges. The next version of this code (pull requests welcome!) will merge together an offset version of the image to “correct” the overlapped regions.
- It accepts wildcards, filenames, or a TSV file containing a subset of files to be processed (similar to HistoQC). This can make producing results on just a test set easier.
- The resize factor should result in the same magnification/microns-per-pixel/resolution as the resize factor used in make_hdf5
- The patch size should match that used in train_unet
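The naive tile-predict-stitch approach described above can be sketched in a few lines of NumPy. Here `predict_batch` is a hypothetical callable standing in for the model (in practice it would move the tiles to the GPU, run the U-Net and take the argmax), and the image is assumed to have already been resized by the same factor used in make_hdf5. The image is reflect-padded up to a multiple of the tile size, and the padding is cropped off afterwards.

```python
import numpy as np

def predict_large_image(img, predict_batch, tile=256):
    """Pad to a multiple of `tile`, predict all tiles as one batch, stitch, then un-pad."""
    h, w, c = img.shape
    ph, pw = (-h) % tile, (-w) % tile                  # padding needed on bottom/right
    padded = np.pad(img, ((0, ph), (0, pw), (0, 0)), mode="reflect")
    nr, nc = padded.shape[0] // tile, padded.shape[1] // tile
    # (nr*tile, nc*tile, c) -> (nr*nc, tile, tile, c), tiles in row-major order
    tiles = padded.reshape(nr, tile, nc, tile, c).swapaxes(1, 2).reshape(-1, tile, tile, c)
    labels = predict_batch(tiles)                      # expected shape: (nr*nc, tile, tile)
    # Inverse of the tiling reshape: put every tile back at its grid position
    stitched = labels.reshape(nr, nc, tile, tile).swapaxes(1, 2).reshape(nr * tile, nc * tile)
    return stitched[:h, :w]

def fake_predict(tiles):
    # Hypothetical per-pixel "model": thresholds the red channel
    return (tiles[..., 0] > 127).astype(np.uint8)

img = np.random.default_rng(0).integers(0, 256, size=(600, 520, 3), dtype=np.uint8)
mask = predict_large_image(img, fake_predict)
print(mask.shape)   # (600, 520)
```

Because the stand-in predictor works pixel by pixel, the stitched result must equal applying it to the whole image at once, which is a handy sanity check for the reshaping. A real network sees less context near tile borders, which is why the edges of each tile suffer.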
I’m hoping that you find this approach useful and generalizable to your own use cases. While building this, I realized how rough the annotations in this dataset are, and yet was pleasantly surprised at how the network was still able to do a pretty reasonable job of producing output.
This project is of course a work in progress and will be updated, so please keep track of the revision numbers at the top of the files to ensure that you’re using the latest and greatest version!
Any questions or comments are of course welcome!
Link to GitHub repository here.