Tags: python, progress-bar, astropy, tqdm, photutils

How to create a progress bar for iterations happening within installed modules


I am aiming to create a progress bar for an iteration happening inside an installed module.

To create a progress bar for an iteration inside a user-defined function, I pass a tqdm.notebook.tqdm_notebook object as the iterable:

import time
import numpy as np
from tqdm.notebook import tqdm

def iterate(over):
    for x in over: # creating progress bar for this
        print(x, end='')
        time.sleep(0.5)

xs = np.arange(5)
tqdm_xs = tqdm(xs) # creating tqdm.notebook.tqdm_notebook object

iterate(tqdm_xs) # progress bar, as expected
iterate(xs) # no progress bar

which works as expected.

However, when I try to do the same for a for loop inside an installed module, it fails. Photutils (an Astropy-affiliated package) has a for label in labels line inside deblend_sources (here), and I can pass in the labels object.

Reproducible example (largely based on this - works after installing photutils: pip install photutils):

import photutils.datasets as phdat
import photutils.segmentation as phsegm
import astropy.convolution as conv
import astropy.stats as stats

data = phdat.make_100gaussians_image()
threshold = phsegm.detect_threshold(data, nsigma=2.)
sigma = 1.5
kernel = conv.Gaussian2DKernel(sigma, x_size=3, y_size=3)
kernel.normalize()
segm = phsegm.detect_sources(data, threshold, npixels=5, kernel=kernel)

This works:

segm_deblend = phsegm.deblend_sources(data, segm, npixels=5, kernel=kernel,
                                      nlevels=32, contrast=0.001, labels = segm.labels)

When I instead pass a tqdm.notebook.tqdm_notebook object to create a progress bar:

tqdm_segm_labels = tqdm(segm.labels)
segm_deblend = phsegm.deblend_sources(data, segm, npixels=5, kernel=kernel,
                                    nlevels=32, contrast=0.001, labels = tqdm_segm_labels)

I get an AttributeError: 'int' object has no attribute '_comparable'. Full traceback:

0%
0/92 [00:00<?, ?it/s]
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-8-d101466650ae> in <module>()
      1 tqdm_segm_labels = tqdm(segm.labels)
      2 segm_deblend = phsegm.deblend_sources(data, segm, npixels=5, kernel=kernel,
----> 3                                     nlevels=32, contrast=0.001, labels = tqdm_segm_labels)

4 frames
/usr/local/lib/python3.7/dist-packages/astropy/utils/decorators.py in wrapper(*args, **kwargs)
    534                     warnings.warn(message, warning_type, stacklevel=2)
    535 
--> 536             return function(*args, **kwargs)
    537 
    538         return wrapper

/usr/local/lib/python3.7/dist-packages/photutils/segmentation/deblend.py in deblend_sources(data, segment_img, npixels, kernel, labels, nlevels, contrast, mode, connectivity, relabel)
    112         labels = segment_img.labels
    113     labels = np.atleast_1d(labels)
--> 114     segment_img.check_labels(labels)
    115 
    116     if kernel is not None:

/usr/local/lib/python3.7/dist-packages/photutils/segmentation/core.py in check_labels(self, labels)
    355 
    356         # check for positive label numbers
--> 357         idx = np.where(labels <= 0)[0]
    358         if idx.size > 0:
    359             bad_labels.update(labels[idx])

/usr/local/lib/python3.7/dist-packages/tqdm/utils.py in __le__(self, other)
     70 
     71     def __le__(self, other):
---> 72         return (self < other) or (self == other)
     73 
     74     def __eq__(self, other):

/usr/local/lib/python3.7/dist-packages/tqdm/utils.py in __lt__(self, other)
     67     """Assumes child has self._comparable attr/@property"""
     68     def __lt__(self, other):
---> 69         return self._comparable < other._comparable
     70 
     71     def __le__(self, other):

AttributeError: 'int' object has no attribute '_comparable'

A workaround is to modify Photutils and use tqdm inside it (which I did on this fork, and it works), but this seems like overkill, and I hope there is an easier way to do this.


Solution

  • In general there is no way to hook into existing code you didn't write yourself without modifying it (whether or not it's "installed" is not the issue).

    If you think it's really of general use or interest you could propose a patch to allow this function to take, e.g., a callback function to call on each loop. It might be useful if it's a slow function in general (I did notice some things in the implementation that could be changed to speed it up, but that's another matter).
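To make the callback idea concrete, here is a minimal sketch of what such a patch could look like. The function process_labels and its progress_callback parameter are hypothetical names made up for illustration; they are not part of Photutils, and the loop body is only a placeholder for the real per-label work.

```python
from typing import Callable, Iterable, List, Optional


def process_labels(
    labels: Iterable[int],
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> List[int]:
    """Hypothetical stand-in for a library function with a slow per-label loop."""
    labels = list(labels)
    results = []
    for done, label in enumerate(labels, start=1):
        results.append(label * 2)  # placeholder for the real per-label work
        if progress_callback is not None:
            # Report progress after each iteration as (completed, total)
            progress_callback(done, len(labels))
    return results


# The caller decides how progress is displayed; here the calls are just recorded.
calls = []
doubled = process_labels(
    [3, 5, 8],
    progress_callback=lambda done, total: calls.append((done, total)),
)
```

With tqdm, the caller could create a bar with tqdm(total=...) and pass a callback that calls bar.update(1), so the library itself never needs to depend on tqdm.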

    You could of course find a number of clever hacks to make it work in this one specific case, though any such hack would be fragile because it is tailored to the implementation details of this function. I found a few possibilities for this.

    The simplest seems to be this stupid trick:

    Make an ndarray subclass (I called it tqdm_array) which when iterated in Python returns an iterator over a tqdm progress bar which wraps the array itself:

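    import numpy as np
    from tqdm.notebook import tqdm  # as imported in the question

    class tqdm_array(np.ndarray):
        def __iter__(self):
            return iter(tqdm(np.asarray(self)))  # progress bar wraps the plain array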

    Then, when preparing to call deblend_sources, wrap your labels in this:

    labels = np.array(segm.labels).view(tqdm_array)

    and pass that to deblend_sources(..., labels=labels, ...).

    This will work because even if labels is iterated over by NumPy code, it will use internal C code to iterate directly over the array buffer (e.g. for operations like labels <= 0). In most cases it won't call the Python-level __iter__ method, though there may be exceptions.

    But when encountering a Python for-loop like for label in labels: (of which there happens to be only one in this function), you'll get your progress bar.
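The claim that NumPy operations bypass __iter__ while a Python for-loop goes through it can be checked without tqdm. The sketch below uses a hypothetical counting_array class (a stand-in for tqdm_array that only counts calls to __iter__) and repeats the two steps visible in the traceback above, np.atleast_1d and the labels <= 0 check. It assumes only that NumPy is installed.

```python
import numpy as np


class counting_array(np.ndarray):
    """Hypothetical stand-in for tqdm_array: counts Python-level iterations."""

    iter_calls = 0

    def __iter__(self):
        type(self).iter_calls += 1
        # np.asarray drops the subclass, so this does not recurse into __iter__
        return iter(np.asarray(self))


labels = np.array([1, 2, 3]).view(counting_array)

# Same preprocessing and check that appear in the traceback
labels = np.atleast_1d(labels)   # keeps the ndarray subclass
bad = np.where(labels <= 0)[0]   # vectorized comparison, no __iter__ call
calls_after_numpy_ops = counting_array.iter_calls

seen = [int(label) for label in labels]  # Python-level loop, goes through __iter__
calls_after_loop = counting_array.iter_calls
```

Here calls_after_numpy_ops stays at 0 and calls_after_loop becomes 1, which is the behavior the trick relies on. If a future version of the function converted labels to a plain array or list before looping, the subclass would be lost and the progress bar would silently disappear.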