Digital Pathology Segmentation using Pytorch + Unet

In this blog post, we discuss how to train a U-Net-style deep learning classifier, using PyTorch, for segmenting epithelium versus stroma regions. This post is broken down into 4 components, following along with other pipeline approaches we've discussed in the past:

  1. Making training/testing databases,
  2. Training a model,
  3. Visualizing results in the validation set,
  4. Generating output.

This approach focuses on using solely Python and freely available tools (i.e., no MATLAB).

This blog post assumes moderate knowledge of convolutional neural networks. Depending on the reader's background, our JPI paper may be sufficient, or a more thorough resource such as Andrew Ng's deep learning course may be warranted.

Introduction to U-Net

U-Nets were introduced here, so please refer to that paper for a complete description. Here we discuss only the high-level intuition needed to complete this tutorial.

unet

A U-Net essentially involves 2 parts: a compression/encoding component and a decompression/decoding component. The input to the network is the original RGB image; the network performs semantic segmentation and produces a binary mask (i.e., assigns each spatial location in the original image to a particular class).

The main benefit of this network architecture lies in the connection between coarse and fine spatial information. Copy-and-crop connections feed fine detail to the decoding layers, allowing for improved spatial localization, while the compression/encoding path allows for increasing levels of abstraction to take place for coarse localization. This abstraction typically occurs using convolution operations with strides greater than 1. For example, the first layer may consist of a 256×256×3 input tile, with subsequent layers being 128×128×nkernels, 64×64×nkernels, 32×32×nkernels, etc., where nkernels is the number of unique kernels learned per layer.
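To make the encoding idea concrete, here is a minimal PyTorch sketch (not the actual blocks from the pytorch-unet repository, which may downsample differently, e.g., via max pooling): each stride-2 convolution halves the spatial resolution while the number of learned kernels grows.

    import torch
    import torch.nn as nn

    # Toy encoder: each stride-2 convolution halves the spatial size
    # while the channel count (nkernels) grows.
    nkernels = 32
    encoder = nn.Sequential(
        nn.Conv2d(3, nkernels, kernel_size=3, stride=2, padding=1),             # 256 -> 128
        nn.ReLU(inplace=True),
        nn.Conv2d(nkernels, nkernels * 2, kernel_size=3, stride=2, padding=1),  # 128 -> 64
        nn.ReLU(inplace=True),
    )

    x = torch.randn(1, 3, 256, 256)   # [batch, channels, rows, cols]
    print(encoder(x).shape)           # torch.Size([1, 64, 64, 64])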

During the decoding/decompression phase, an up-convolution operation (i.e., ConvTranspose2d) takes place, which essentially undoes this operation via learnable parameters, such that the 32×32 representation becomes 64×64, then 128×128, etc., until an output image of the original resolution (i.e., size in pixels) is obtained, except hopefully ( ::fingers crossed:: ) this time containing a segmentation mask. Alternatively, instead of a learned up-convolution layer, an interpolation layer can be used, which has no learnable parameters but may produce inferior results.
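As a minimal sketch of the two options (channel counts here are illustrative only), a learnable ConvTranspose2d doubles the spatial size via learned weights, while nn.Upsample interpolates with no parameters:

    import torch
    import torch.nn as nn

    x = torch.randn(1, 64, 32, 32)   # a 32x32 feature map

    upconv = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)                 # learnable
    interp = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)   # no parameters

    print(upconv(x).shape)   # torch.Size([1, 32, 64, 64])
    print(interp(x).shape)   # torch.Size([1, 64, 64, 64])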

The network then compares the model's output mask with the provided ground truth. Each pixel contributes to the network loss, and as such contributes to the error derivative for each learnable parameter. Since the decoding layers are provided with higher-resolution kernel activations through the copy-and-crop connections, the network can more easily propagate context/spatial information. This allows us to train models which have better spatial awareness than naïve pixel-level classifiers applied using a sliding window.
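A minimal sketch of such a per-pixel loss (shapes are illustrative, and the per-pixel weight map anticipates the edge weighting discussed later; the training notebook organizes this differently):

    import torch
    import torch.nn as nn

    # reduction='none' keeps one loss value per pixel, which can then be
    # scaled by a per-pixel weight map before averaging.
    criterion = nn.CrossEntropyLoss(reduction='none')

    prediction = torch.randn(4, 2, 256, 256, requires_grad=True)  # [batch, n_classes, rows, cols]
    mask = torch.randint(0, 2, (4, 256, 256))                     # [batch, rows, cols], class index per pixel
    weight_map = torch.ones(4, 256, 256)                          # e.g., >1 on object edges

    loss_matrix = criterion(prediction, mask)                     # [batch, rows, cols]
    loss = (loss_matrix * weight_map).mean()
    loss.backward()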

Making a database

Regardless of the desired model type used for the classifier, deep learning is typically best performed using some type of database backend. This database need not be sophisticated (e.g., the LMDB commonly used by Caffe), nor must it be a true database (e.g., here we'll discuss using PyTables, which has an HDF5 backend). The benefits of using a database are numerous, but we can briefly list some of them here:

  • As individual files, extracted image “patches” can often number into the thousands or millions in the case of DL. The access time can be very slow as modern operating systems are not typically tuned for efficient access of large numbers of files. Trying to do something as simple as “ls” in a directory in Windows/Linux with millions of files can cause a notable lag.
  • Improved speed as a result of reading off of disk, OS-level caching, and improved compression.
  • Improved reproducibility: coming back to a project later which has a database allows for much more consistent retraining of the model, especially when other files may have been lost (e.g., lists of training and testing sets).
  • Better protection against data leakage. Although not required, I prefer creating separate databases (in this case files), for training, validation, and testing purposes. This can greatly reduce human error in making sure that only training images are read during training time, and that validation/test images are completely held out.

I've previously written a blog post about how to get images into an HDF5 table using MATLAB, but this blog post uses only Python and is more refined. We also focus on the types of data necessary for a U-Net. This comes down to storing both the original image and its associated mask in the database:

epi epi_mask

The code to do this is available here.

One small caveat: this code chops images into non-overlapping tiles, which is very fast thanks to numpy views. If one would like overlapping tiles, the code itself would have to be modified; this is planned for future versions.
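For intuition, here is a minimal numpy sketch of non-overlapping tiling (the notebook itself may be organized differently): the image is cropped to a multiple of the tile size, then reshaped and transposed so the heavy lifting is done by array operations rather than Python loops.

    import numpy as np

    def tile_image(img, patch_size=256):
        nrow, ncol, nchannel = img.shape
        nrow, ncol = nrow // patch_size * patch_size, ncol // patch_size * patch_size
        img = img[:nrow, :ncol]                                  # crop to a clean multiple of the tile size
        tiles = img.reshape(nrow // patch_size, patch_size,
                            ncol // patch_size, patch_size, nchannel)
        tiles = tiles.transpose(0, 2, 1, 3, 4)                   # group the two tile indices together
        return tiles.reshape(-1, patch_size, patch_size, nchannel)  # [ntiles, rows, cols, channels]

    img = np.random.randint(0, 255, (1000, 1200, 3), dtype=np.uint8)
    print(tile_image(img).shape)   # (12, 256, 256, 3), i.e., 3 x 4 tiles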

That said, some specific items worth pointing out in the code:

  • As per cell 2, images and masks are stored as unsigned 8-bit integers, and thus their values are in the range [0, 255].
  • Cell 3 assumes that each image represents a unique patient and thus naively splits the images into training and validation batches. If this is not the case, it should be addressed there by assigning appropriate files to each of the phase dictionary items. Always remember, training and validation should take place at a patient level (i.e., a patient is either in the training set or the testing set, but never both).
  • We use a modest compression level of 6, which some experiments have shown to be a nice trade-off between time and size. This is easily modifiable in the code by changing the value of "complevel".
  • Images are stored in the database in [IMAGE, NROW, NCOL, NCHANNEL] format, which allows for easier retrieval later. Likewise, the chunk size is set to the tile size, such that each tile is compressed independently of the other items in the database, allowing for rapid data access (a minimal sketch of this setup appears after this list).
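A minimal PyTables sketch of the storage layout described above (the array names and the single appended tile are illustrative; the actual notebook also stores additional metadata):

    import numpy as np
    import tables

    patch_size = 256
    filters = tables.Filters(complevel=6, complib='zlib')   # modest compression, as discussed above

    with tables.open_file('train.pytable', mode='w') as db:
        # [IMAGE, NROW, NCOL, NCHANNEL] layout, chunked one tile at a time so
        # each tile is compressed (and can be read back) independently
        img_arr = db.create_earray(db.root, 'img',
                                   atom=tables.UInt8Atom(),
                                   shape=(0, patch_size, patch_size, 3),
                                   chunkshape=(1, patch_size, patch_size, 3),
                                   filters=filters)
        mask_arr = db.create_earray(db.root, 'mask',
                                    atom=tables.UInt8Atom(),
                                    shape=(0, patch_size, patch_size),
                                    chunkshape=(1, patch_size, patch_size),
                                    filters=filters)
        # append one (image, mask) tile pair; real code loops over all tiles
        img_arr.append(np.zeros((1, patch_size, patch_size, 3), dtype=np.uint8))
        mask_arr.append(np.zeros((1, patch_size, patch_size), dtype=np.uint8))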

Training a model

Now that we have the data ready, we'll train the network. One of the joys of working in the DL field currently is the number of people who are willing to open source their code; in this case, the U-Net network itself was made available on github here https://github.com/jvanvugt/pytorch-unet, in a very modular and easily extendable fashion. Thus the main component that we need to develop and discuss here is how to get our data into and out of the network.

Code is available here, and the dataset was previously released here.
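To give a feel for the data-loading side, here is a stripped-down Dataset that pulls tiles out of the pytable created earlier (the node names 'img' and 'mask' and the lack of augmentation and edge weights are simplifying assumptions; the actual notebook adds affine and color augmentation plus an edge-weight map):

    import tables
    import torch
    from torch.utils.data import Dataset, DataLoader

    class PytableDataset(Dataset):
        def __init__(self, fname):
            self.fname = fname
            with tables.open_file(fname, 'r') as db:
                self.nitems = db.root.img.shape[0]

        def __len__(self):
            return self.nitems

        def __getitem__(self, index):
            # open the file per access so the Dataset stays picklable for num_workers > 0
            with tables.open_file(self.fname, 'r') as db:
                img = db.root.img[index, ...]    # [rows, cols, 3], uint8
                mask = db.root.mask[index, ...]  # [rows, cols], uint8
            img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0  # -> [3, rows, cols]
            mask = torch.from_numpy(mask).long()
            return img, mask

    dataloader = DataLoader(PytableDataset('train.pytable'), batch_size=4, shuffle=True)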

One important practice, which is commonly overlooked, is to visually examine a sample of the input which will be fed to the network, which we do in this cell:

unet_input

We can see that the augmentations have been done properly (both the mask and image match), and the edges have been sufficiently identified (discussed below). We can also note how the color augmentation has drastically changed the appearance of the image (left) from H&E to a greenish color space, in hopes of greater generalizability later. Here I'm using the default values as an example, but tuning these values will likely improve results if they're tailored towards the specific test set of interest.

Some notes:

  • In the above image of the U-Net, there are blocks consisting of multiple convolution layers, all of which have the same number of rows and columns. The depth specifies the number of these blocks going down; each block begins with a convolution using stride=2 to reduce the spatial size of the input.
  • wf specifies the number of kernels which will be learned at each layer, and is a factor applied to each block, defined as 2**(wf+layer_number). Thus deeper blocks learn a larger number of kernels.
  • The depth=6 and wf=5 parameters are the values in the original paper, but for most applications I've used them for, this is significant overkill (the number of parameters is 31,043,586); consider making both of them (especially wf) smaller if there are overfitting, speed, or memory issues. While the patch size is 256 here, a depth of 5 can be supported (256 / (2^5) > 1); for smaller patches a smaller depth would be required.
  • edge_weight allows for synthetically adding an additional penalty for making errors on edges. By doing so, we can help the network understand that although edge pixels are infrequent (as compared to the bulk of the object), they are very important and need special attention (see Figure 3 here). The edges in this code are determined automatically via fast morphological operations. If more sophisticated or computationally demanding approaches are used, it is not unreasonable to perform this computation once and save it inside the DB alongside the mask. In this particular dataset, the annotations are very coarse, so we don't use a strong edge weight, but when the dataset annotations are more finely produced, higher values are justified, even as high as 10, implying that edge errors contribute an order of magnitude more to the loss value.
  • Important to note for augmentation: we set the random seed and perform all affine augmentations before color augmentations; this ensures that the mask and edge images will be able to follow the same exact procedure by resetting the seed.
  • In this configuration, setting num_workers>0 actually slows down the training. It seems that the overhead isn't justified when simply loading from a pytable (which implies our DB backend is pretty fast!). Your mileage may vary.
  • This code is heavily reused for both training and validation through the notion of "phases"; in particular, cell 12 contains a bit of intertwined code which sets appropriate backend functionality (e.g., enabling/disabling gradient computation).
  • We use tensorboardX to save real-time statistics from python for viewing in tensorboard, helping to visualize the training progress. We also break statistics down into true positives, false positives, true negatives and false negatives, to see if a certain class is being favored (which can then be adjusted using the weights in cell 10). Here we can see that after about 30 epochs we've already reached an accuracy of 90%.

tensorboard

  • We save the best snapshot, overwriting the previous one. Along with it, we save all information necessary both to load the network and to continue training later on.
  • The bottom of the notebook shows how to both visualize individual kernels and visualize activations. Note that, to be efficient, pytorch does not keep activations in memory after the network is done computing, so it is impossible to retrieve them after the model does its prediction. As a result, we create a hook which saves the activations we're interested in when the layer of interest is encountered (a minimal sketch of such a hook appears after the figure below).

activations
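A minimal sketch of such a hook (the choice of down_path[0] assumes the jvanvugt implementation, which keeps the encoder blocks in a ModuleList called down_path; adapt the module reference, and the `model`, `img`, and `device` variables, to your own setup):

    import torch

    activations = {}

    def save_activation(name):
        def hook(module, inputs, output):
            activations[name] = output.detach().cpu()
        return hook

    # register the hook on the layer of interest
    handle = model.down_path[0].register_forward_hook(save_activation('down0'))

    with torch.no_grad():
        _ = model(img.unsqueeze(0).to(device))   # the forward pass populates `activations`

    print(activations['down0'].shape)
    handle.remove()                              # detach the hook when done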

Visualizing results in the validation set

Since we're consistently saving the best model as the classifier is training, we can easily visualize the results on the validation set while the network itself is training. This is best done if 2 GPUs are available, so that the main GPU can continue training while the second GPU generates the output. If the network is small enough and the memory of a single GPU large enough, both processes can be done using the same GPU.
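The core of this is simply loading the latest "best" checkpoint onto the second GPU, for example as sketched below (the filename and checkpoint keys are assumptions; match them to whatever the training script actually saves, and `model` is assumed to be an already-constructed U-Net):

    import torch

    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')

    checkpoint = torch.load('epistroma_unet_best_model.pth', map_location=device)
    model.load_state_dict(checkpoint['model_dict'])
    model.to(device)
    model.eval()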

A minimal example of how to do this is provided here. Note that there is a very large code overlap with the training code, except pruned down solely to generating output.

Some notes:

  • Augmentation in cell 6: when using this for output generation, we want to use the original images, since they will give a better sense of the expected output when used on the rest of the dataset; as a result, we disable all unnecessary augmentation. The only component that remains here is the random crop, to ensure that regardless of the size of the image in the database, we extract an appropriately sized patch.

validation_output

  • We can see the output has 4 columns as follows: (a) DL raw output, (b) DL output after argmax, (c) ground truth, (d) original image

Generating Output

Lastly, of course, we'd like to apply the model to generate output on new images. For that, I've written a prototype command-line Python script, available here.

Given the other files, the code here should be rather self-explanatory when it comes to generating the output.

The output is generated in a naïve fashion: a large image is chopped into tiles, those tiles are fed as a batch to the model, and the results are merged back together. This usually does well for most of the tile but tends to do poorly around the edges. The next version of this code (pull requests welcome!) will merge together an offset version of the image to "correct" the overlapped regions.
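A minimal sketch of this naive tile-and-stitch approach (one tile at a time for clarity, whereas the script batches tiles; it also assumes the network was built with padding so the output size matches the input, and that `model` and `device` are set up as in the training section):

    import numpy as np
    import torch

    def segment_image(img, model, device, patch_size=256):
        nrow, ncol, _ = img.shape
        nrow, ncol = nrow // patch_size * patch_size, ncol // patch_size * patch_size
        output = np.zeros((nrow, ncol), dtype=np.uint8)
        model.eval()
        with torch.no_grad():
            for r in range(0, nrow, patch_size):
                for c in range(0, ncol, patch_size):
                    tile = img[r:r + patch_size, c:c + patch_size]
                    x = torch.from_numpy(tile).permute(2, 0, 1).float() / 255.0
                    pred = model(x.unsqueeze(0).to(device))       # [1, n_classes, H, W]
                    output[r:r + patch_size, c:c + patch_size] = (
                        pred.argmax(dim=1).squeeze(0).cpu().numpy())
        return output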

Some notes:

  • It accepts wildcards, filenames, or a TSV file containing a subset of files to be computed (similar to HistoQC). This can make producing results on just a test set easier.
  • The resize factor should result in the same magnification/microns-per-pixel/resolution as the resize factor used in make_hdf5
  • The patch size should match that used in train_unet

Conclusion 

I'm hoping that you find this approach useful and generalizable to your own use cases. I've realized while building this how rough the annotations in this dataset are, and yet I was pleasantly surprised at how the network was still able to do a pretty reasonable job in producing output.

This project is of course a work in progress and will be updated, so please keep track of the revision numbers at the top of the files to ensure that you’re using the latest and greatest version!

Any questions or comments are of course welcome!

Link to the github repository here.

45 thoughts on “Digital Pathology Segmentation using Pytorch + Unet”

  1. Dear professor Andrew J.
    I have read your recently published paper 'H&E-stained Whole Slide Image Deep Learning Predicts SPOP Mutation State in Prostate Cancer', which intrigued me a lot. And I have two questions to consult you: 1. As you recommended in August, the WSI of FFPE format from TCGA were more amenable to computational analysis, but why did you use frozen slides from TCGA for training and FFPE from MSK-IMPACT for testing? 2. We can download WSI from TCGA, but is there any access for me to download the WSI of MSK-IMPACT? Thank you for your consideration.

  2. I am so sorry to make such a stupid mistake. It shows on the top left of this page 'Andrew Janowczyk' and the first author of that paper was 'Andrew J. Schaumberg', which made me confused. Apologies for making this mistake; I am expecting more guidance from you in the future. Thank you.

  3. Dear Andrew,
    Firstly, your paper and code are very nice and useful for me to dig into the computer vision field. I appreciate your sharing.
    However, I have a question about the sklearn (scikit-learn) version that you used in your code to convert the images to a *.pytable file, because I can't do from sklearn import cross_validation. I read the docs for version 0.20.2 and changed the function to from sklearn import model_selection, and after that got a bug in cell 4 of your notebook (make_hdf5.ipynb). I tried to fix the bug following the scikit-learn docs and overcame this. However, I am stuck on cell 6 at line 31.
    Can you explain or update your code?
    Linh

    1. try looking at the same file in the classification/densenet code. the relevant lines are rewritten to use the latest version of sklearn: phases["train"],phases["val"]=next(iter(model_selection.ShuffleSplit(n_splits=1,test_size=test_set_size).split(files)))

  4. Dear Andrew,
    I really appreciate you sharing the nice code. However, I have a problem with the "for ii , (X, y, y_weight) in enumerate(dataLoader[phase]): #for each of the batches" line.
    I am getting a "TypeError: self.dims,self.dims_chunk,self.maxdims cannot be converted to a Python object for pickling" error, which points to the multiprocessing library (reduction.py) line 60 (ForkingPickler(file, protocol).dump(obj)).
    I am not sure what to do and would be grateful for any help.
    My environment is Windows 10, and Python 3.6m.

    1. I suppose that the error is about the difference between multithreading in Linux and Windows, and I hope that running the code with num_workers = 0 would solve it, but I am not sure if it is the right solution or not?

      1. yes, this is exactly the case. num_workers = 0 will get things to “work”, but will be less efficient than when num_workers>0 *and* there is significant augmentation present. realistically for num_workers <=2 i typically see things get slower instead of faster (overhead with launching forks and collecting results), so perhaps the performance hit isn't so bad. regardless, i'm hoping they'll come up with a clever way of fixing these issues and it will be resolved soon!

        1. Has the problem of multithreading on Windows been solved?

          I get the same error.

          num_workers=0 is too slow: 10 hours!!

          1. there are some tricks to get it working, but ultimately i'm finding that unless the DL network is very small and the augmentation is very big, the bottleneck imposed by the batch creation is quite small. i'd suggest using a line profiler first to get a feeling for which parts of the code are taking the most time before worrying about optimizing

  5. Hey Andrew,
    Thanks for the well documented code on github. I am trying to use your patch extraction (hdf5 table creation) code on a different dataset – FFPE images from the genomic data commons. I run into a memory error, >120 gigs of virtual memory used, when I try to reshape the array after patch extraction. Is it possible to reshape the array and write to the hdf5 table in chunks?

    1. which line is it exactly that's causing the out of memory error? tinkering with the code should make it possible to simply change the underlying view so that no additional memory needs to be addressed. another option is to look at the output generation code (https://github.com/choosehappy/PytorchDigitalPathology/blob/master/segmentation_epistroma_unet/make_output_unet_cmd.py) which submits chunks to the GPU. the same technique could be used to batch over each e.g. column and then save that to the DB. it doesn't matter what order things go in, as long as it's in there. additionally, if your image isn't 100% informative (e.g., there is background present) you should really be selecting "appropriate" ROIs (those which contain relevant tissue information) to try and reduce the noise in the training set

      1. The line of code giving me issues is reshaping the array : io_arr_out=io_arr_out.reshape(-1,patch_size,patch_size,3).
        Thanks for the suggestions, I’ll probably try to split io_arr_out into different chunks prior to reshaping. If that doesn’t work I’ll try the output generation example and reducing the image down to ROIs. Thanks again!

        1. Great. If you figure something out, feel free to submit the code! Always interested in improving : )

  6. this error showed up TypeError: self.dims,self.dims_chunk,self.maxdims cannot be converted to a Python object for pickling

    while I trained with the nuclei dataset from your website

    why?

    1. Are you on a windows machine? You may need to change # of workers to 0, pytorch doesn’t support forking functionality in windows very well

  7. Hi. I am a newbie studying Python and deep learning.

    Can I ask something about the data I should use for this code?

    What is a mask?

    In make_hdf5, we use files=glob.glob(mask/*.png) and imread("./imgs/"+os.path.basename(fname).replace("_mask.png",".tif")),cv2.COLOR_BGR2RGB) and imread(fname)/255 (fname=files(filei)).

    So we are using the img_mask.png and img.tif files, aren't we?

    Are img_mask.png and img.tif the same? To ask again: what is the mask? Is it just an image that gets input and processed, some intermediate result, or is it used for convolutional filtering?

    And does Dataset return img_new, mask_new, weight_new? And what does DataLoader return?

    Please answer me.

    1. To correct something:
      So we are using the img_mask.png and img.tif files, aren't we?
      I use the same image file for mask.png and img.tif.

        1. the fused image is often much easier to visualize the results from the segmentation. the binary output is typically used for downstream processes

      1. you should be able to download and run the code as is, with the dataset as provided. i would suggest not modifying anything until you at least have a base version running

    2. given your level, i would suggest you start with this manuscript and potentially this online course: https://www.coursera.org/specializations/deep-learning further, you should be able to download the datasets and step through the code to answer most of your questions. you’ll also want to look at the pytorch documentation, for example for information relating to the data loader https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

  8. Hi. I have a question.

    Enumerating the dataloader raises an error, so I use num_workers=0.

    But it is much too slow.

    When running this code with num_workers=0, how long does it take?

    1. if you’re not using heavy augmentation with a large network, the training time will be around the same. the biggest constraint for this code is getting data on the gpu, not pulling it from the database or augmenting it. if you scroll up you’ll find a number of comments regarding the # of workers. if you’re interested in comparing, i’d consider using a linux machine where the worker parameter works

      1. if i remember, it's pretty straightforward to test on windows. you need to convert the notebook to a routine python py file, and wrap the main code inside of an if __name__ == '__main__' statement. then the workers will be able to fork properly

  9. Dear Andrew,

    Firstly, I appreciate you sharing your knowledge and code. It was very helpful and useful for me. However, I have two questions.

    1) Why are UNet often used for medical image segmentation rather than the other recent network architecture such as DeepLab?

    2) The following loss function code uses different shapes of inputs. (shape of prediction is [N, Nclass, H, W] and shape of y(mask) is [N, H, W])

    loss_matrix = criterion(prediction, y)

    My understanding is y(mask) must be transformed to one-hot encoded target vector ([N, Nclass, H, W]) for calculating pixel-wise loss. In this case, I think y (mask) shape should be [N, 2, H, W]. I’m asking because I want to change this code to multi-class segmentation.

    It would be greatly appreciated if you could explain the details.

    1. thanks for your questions. as i understand it, u-net is usually used for medical images due to its ability to learn reasonably well from small datasets. in particular, in medical sciences, obtaining a significant number of annotations or samples is challenging. additionally, the annotations themselves tend to be at least a little noisy, ultimately limiting the maximal performance of any classifier. unet has been around for a lot longer than many of these semantic segmentation approaches like deeplab. that said, higher order research approaches use the segmentation results, but aren't very interested in developing the segmentation methods themselves. as a result, i think researchers have a lot of legacy code laying around which they are more likely to apply first. if the results are sufficient for their particular task, then they'll move on instead of validating a number of different approaches. for example, since we have this u-net pipeline already in place, and tend to have a good feeling for how it responds and how to adjust it, we'll use it first. to date, i've not encountered a situation where it performed "poorly enough" to justify investing effort in other approaches. does that make sense?

      in regards to the second question. take a look at the documentation: https://pytorch.org/docs/stable/nn.html#torch.nn.CrossEntropyLoss the input and target masks can be of either style (one-hot or encoded). it is much more efficient to use the encoded style. to change the code for multi class, simply change n_classes in the training script and "classes" in the make hdf5 script, and put appropriate masks in the database; it should work without issue after that.

  10. Hi, I have a question.
    In ‘train_densenet_albumentations.py’ line 175
    (img, label, img_old)=dataset[“train”][7]
    There is an error below:
    IndexError: Too many indices for object ‘/imgs’
    I have no idea how to solve this. Could you do me a favor?

    1. sorry, i just recloned the repository and reran the code and wasn’t able to reproduce your issue. are you sure you’ve run the code properly?

    1. Probably not. The magnification needed to identify cancer is usually too high such that it becomes challenging to find an appropriate patch size for the u-net. ultimately each mask ends up being almost entirely white or entirely black, which is a good indication that a classify (instead of segment) approach is more appropriate. That is precisely what I have done (see figure 7, https://www.jpathinformatics.org/article.asp?issn=2153-3539;year=2016;volume=7;issue=1;spage=29;epage=29;aulast=Janowczyk;t=6) as well as others (https://arxiv.org/abs/1606.05718). As such a good approach is to use this tutorial

    1. Thanks for your question! I’m not sure how to answer that, normally people ask questions in the positive sense (why you did something) versus the negative sense (why didn’t you) : ) Why do you think normalizing would be important/beneficial in this context?

  11. Hi, I use the code "make_output_unet_cmd.py" to generate the output, but the result is an all-white or all-black image. I have no idea how to solve this. Could you give me an answer? Thank you very much.
