
Annotation auto-placement in a plot


I have code that automatically handles coloring and plotting of multiple plots, which makes things easy (for me). I want to make annotation just as easy:

Goal: if an annotation at xy conflicts with a previous one, shift it (say, up) until it no longer conflicts with any other annotation.

  1. If a function that can already do this exists, that would be a dream; I couldn't find one.

  2. Otherwise, what's the best way to list the annotations and get their bounding boxes in a common coordinate system?
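For question 2, here is a minimal sketch, assuming a headless Agg canvas. Annotations created with `ax.annotate` are `matplotlib.text.Annotation` artists among the axes' children. After one draw, each one reports its bounding box in display (pixel) coordinates through `get_window_extent()`, and `ax.transData.inverted()` maps that box back to data coordinates. The helper name `annotation_boxes` is hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed for this sketch
import matplotlib.pyplot as plt
from matplotlib.text import Annotation


def annotation_boxes(ax):
    """Return (annotation, display_bbox, data_bbox) for every annotation on ax."""
    # one draw so the text layout and the axis limits are final
    ax.figure.canvas.draw()
    renderer = ax.figure.canvas.get_renderer()
    to_data = ax.transData.inverted()
    result = []
    for child in ax.get_children():
        if isinstance(child, Annotation):
            bbox = child.get_window_extent(renderer)  # in pixels
            result.append((child, bbox, bbox.transformed(to_data)))
    return result


fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 0])
ax.annotate("a", xy=(0, 0))
ax.annotate("b", xy=(1, 1))
for ann, disp, data in annotation_boxes(ax):
    print(ann.get_text(), disp.bounds, data.bounds)
```

Two display boxes can then be compared directly with `Bbox.overlaps`.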

I have code for auto-coloring that looks like this:

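import matplotlib.collections
import matplotlib.lines

if chain:
    # collect the scatter (PathCollection) and line (Line2D) artists on the axes
    children = [child for child in Iplot.axes.get_children()
                if isinstance(child, (matplotlib.collections.PathCollection,
                                      matplotlib.lines.Line2D))]
    # spread Iplot.col over the existing artists plus len(args) more
    col_step = 1.0/(len(children)+len(args))
    for child in children:
        # red rises and blue falls by col_step per artist
        child.set_color([Iplot.col, 0, 1-Iplot.col])
        Iplot.col += col_step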

I could do something similar for annotations (changing the if statement and the body of the second loop), but 1) I don't like this piece of code, and 2) I'm hoping something more elegant exists.


Solution

  • This is my solution, one I was trying to avoid. In the old question linked in the comments, someone mentions that this is an NP-complete problem, but I want to point out that this doesn't really matter here: this greedy approach doesn't look for an optimal layout, it just shifts each new label until it is free. I tested up to 26 annotations, and it takes a few seconds at most. No practical plot will have 1000 annotations.

    Caveats:

    1. As mentioned, this isn't super fast. In particular, I wish I could avoid calling draw(). It's OK now, since it only draws twice.
    2. This code only lets new annotations slide in a single orthogonal direction (left, right, up, or down), but this could be extended.
    3. The arrow placement depends on the window: make sure the window size and the axes limits do not change after annotating (with arrows), and re-annotate if you resize.
    4. It is not dynamic; see point 3.
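One possible way to soften caveats 3 and 4, sketched here as a suggestion rather than part of the solution: instead of a separate `axes.arrow` in data coordinates, let the annotation draw its own connector through `arrowprops`. That connector is the annotation's `arrow_patch`, which matplotlib recomputes from the text position and `xy` at every draw, so it stays attached after a resize. The pixel offset itself is still not re-solved for overlaps. The Agg backend and the sizes below are assumptions for the example.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed for this sketch
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], "o")
# text 40 px above the data point; arrowstyle "-" draws a plain connecting line
ann = ax.annotate("label", xy=(1, 1), xytext=(0, 40),
                  textcoords="offset pixels",
                  arrowprops=dict(arrowstyle="-"))
fig.canvas.draw()
fig.set_size_inches(8, 3)  # simulate a window resize
fig.canvas.draw()          # the connector is recomputed on this draw
```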

    Background:

    1. Iplot is a helper class I use for multi-plot figures; it handles coloring, sizing, and now annotating.
    2. plt is matplotlib.pyplot.
    3. This method handles one or several annotations and can now resolve conflicts.
    4. As you may have guessed, Iplot.axes holds my axes.

    EDIT: I removed my class code to make this easier to copy and paste. The axes should be passed to the function, and kwargs accepts an existing boxes keyword so that previous annotations are taken into account; that list is edited in place. (I use a class to encapsulate this.) The function also returns the boxes, in case this is the first call.

    EDIT 2

    After a while I sped this up: there is no need to draw so much; it is better to loop twice and update the renderer once in between.
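Stripped to its essentials, the two-pass idea can be sketched on its own. This is a simplified version, not the code that follows: it only slides up, assumes a headless Agg canvas, and relies on `Annotation.get_window_extent()` recomputing the text position after `set_position`, so a single draw between the passes is enough. The function name `place_up` and its `boxes`, `offset` and `max_steps` parameters are hypothetical; `boxes` mirrors the in-place boxes keyword described above.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed for this sketch
import matplotlib.pyplot as plt


def place_up(ax, labels, points, boxes=None, offset=(0, 2), max_steps=100):
    """Annotate points, sliding each label up until it overlaps no earlier box.

    boxes: display-space Bbox list from earlier calls; extended in place.
    """
    boxes = [] if boxes is None else boxes
    # pass 1: create every annotation at the requested pixel offset
    anns = [ax.annotate(lab, xy=p, xytext=offset, textcoords="offset pixels")
            for lab, p in zip(labels, points)]
    ax.figure.canvas.draw()  # one draw: fixes text layout and axis limits
    # pass 2: resolve conflicts one annotation at a time
    for ann in anns:
        for _ in range(max_steps):
            cbox = ann.get_window_extent()
            hits = [b for b in boxes if cbox.overlaps(b)]
            if not hits:
                break
            # put the bottom edge 1 px above the highest conflicting top edge
            dy = max(b.y1 for b in hits) - cbox.y0 + 1
            x, y = ann.get_position()  # the offset in pixels (textcoords)
            ann.set_position((x, y + dy))
        boxes.append(ann.get_window_extent())
    return boxes


fig, ax = plt.subplots()
ax.plot([0, 1], [0, 0], "o")
boxes = place_up(ax, ["first", "second"], [(0.5, 0), (0.5, 0)])
boxes = place_up(ax, ["third"], [(0.5, 0)], boxes=boxes)
```

The 1-pixel gap matters because `Bbox.overlaps` treats boxes that merely touch as overlapping; `max_steps` only guards against an unexpected endless loop.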

    The code:

    import numpy as np

    def annotate(axes, labels, data, **kwargs):
        # slide indexes the bbox edge to push against, Bbox.get_points()[slide]:
        # (0,0) left, (0,1) bottom, (1,0) right, (1,1) top; pass it as a tuple
        slide = kwargs.pop("slide", None)
        # one (dx, dy) pixel offset for every label, or one offset per label
        xytexts = kwargs.pop("xytexts", (0, 2))
        if np.ndim(xytexts) == 1:
            xytexts = [xytexts] * len(labels)
        # display-space boxes of earlier annotations; extended in place and returned
        boxes = kwargs.pop("boxes", [])
        pixel_diff = 1
        fig = axes.figure

        # first pass: create every annotation at its requested offset
        newlabs = []
        for label, xy, xytext in zip(labels, data, xytexts):
            newlabs.append(axes.annotate(label, xy=xy, textcoords='offset pixels',
                                         xytext=xytext, **kwargs))
        # one synchronous draw gives every annotation a renderer and a layout
        fig.canvas.draw()

        # second pass: slide each new annotation until it clears the earlier boxes
        for a, xy in zip(newlabs, data):
            cbox = a.get_window_extent()
            if slide is not None:
                direct = int((slide[0] - 0.5)*2)  # -1: slide left/down, +1: right/up
                current = -direct*float("inf")
                arrow = False
                while True:
                    overlaps = False
                    for box in boxes:
                        if cbox.overlaps(box):
                            if direct*box.get_points()[slide] > direct*current:
                                overlaps = True
                                current = box.get_points()[slide]
                                shift = direct*(current - cbox.get_points()[1-slide[0], slide[1]])
                    if not overlaps:
                        break
                    arrow = True
                    # the text position is the pixel offset; floats avoid truncation
                    position = np.array(a.get_position(), dtype=float)
                    position[slide[1]] += shift*direct*pixel_diff
                    a.set_position(position)
                    # get_window_extent() recomputes the position, no redraw needed
                    cbox = a.get_window_extent()
                if arrow:
                    # connect the moved text's lower-left corner to its data point
                    x, y = axes.transData.inverted().transform(cbox.get_points())[0]
                    axes.arrow(x, y, xy[0]-x, xy[1]-y, head_length=0, head_width=0)
            boxes.append(cbox)
        fig.canvas.draw_idle()
        return boxes
    

    Any suggestions for improvement are warmly welcome. Many thanks!