Tags: python, algorithm, permutation, in-place

Determining processing order for (almost) in-place computation


I have an i-channel image, image. I also have f filters, filters, each of which can be applied to a single channel. I want to generate an o-channel image, output, by selectively applying filters to the channels of the image. I currently have this defined with two lists, image_idx and filter_idx, so that processing is done as

for j in range(o):
    output[j] = filters[filter_idx[j]](image[image_idx[j]])
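As a baseline, here is a minimal out-of-place version of that loop on toy data, handy for checking any in-place variant against. It assumes channels are plain Python values and filters is a list of callables; apply_out_of_place and the sample data are hypothetical.

```python
def apply_out_of_place(image, filters, image_idx, filter_idx):
    """Return a new list with output[j] = filters[filter_idx[j]](image[image_idx[j]])."""
    # Every output reads the untouched input, so the processing order does not matter here.
    return [filters[filter_idx[j]](image[image_idx[j]])
            for j in range(len(image_idx))]


# Toy data (hypothetical): channels are numbers, filters are plain callables.
image = [10, 20, 30]
filters = [lambda c: c, lambda c: -c]
output = apply_out_of_place(image, filters, image_idx=[2, 0, 0, 1], filter_idx=[0, 1, 0, 1])
print(output)  # [30, -10, 10, -20]
```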

Because the images can be pretty large, I may want to do this processing in place. This may require processing the channels in a specific order, to avoid overwriting data that will be needed later. I am currently checking whether such an order exists, and computing it, with the following function:

def in_place_sequence(indices):
    """
    Figure out a processing sequence for in-place computation.
    """
    indices = list(indices)                # work on a copy
    positions = set(range(len(indices)))   # outputs not yet written
    processing_order = []
    change = True
    while change and positions:
        change = False
        for j in list(positions):
            val = indices[j]
            # Position j can be overwritten if no pending output reads it,
            # or if the only pending output reading it is j itself.
            if (j not in indices) or (indices.count(val) == 1 and val == j):
                indices[j] = None          # j no longer reads anything
                positions.remove(j)
                processing_order.append(j)
                change = True
    if positions:
        return None                        # stuck: no safe order exists
    return processing_order

For example:

In [21]: in_place_sequence([4, 0, 3, 0, 4])
Out[21]: [1, 2, 3, 0, 4]

And a possible processing order to avoid overwriting would be:

img[0] -> img[1]
img[3] -> img[2]
img[0] -> img[3]
img[4] -> img[0]
img[4] -> img[4]
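A quick way to double-check an order like the one above is to verify it directly: writing position j is safe only if no position written later still reads channel j. A sketch; is_safe_order is a hypothetical helper, not part of the question's code, and its quadratic scan is only meant for a handful of channels.

```python
def is_safe_order(indices, order):
    """Return True if writing positions in `order` never overwrites a channel
    that a later write still reads (position j reads channel indices[j])."""
    if sorted(order) != list(range(len(indices))):
        return False                     # every position must be written exactly once
    for step, j in enumerate(order):
        # Reading indices[j] happens before writing j, so j reading itself is fine;
        # the problem is any later write that still needs the old channel j.
        if any(indices[k] == j for k in order[step + 1:]):
            return False
    return True


print(is_safe_order([4, 0, 3, 0, 4], [1, 2, 3, 0, 4]))  # True
print(is_safe_order([4, 0, 3, 0, 4], [0, 1, 2, 3, 4]))  # False
print(is_safe_order([2, 0, 3, 1], [0, 1, 2, 3]))        # False
```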

The order is then applied roughly as follows:

for j in in_place_sequence(image_idx):
    image[j] = filters[filter_idx[j]](image[image_idx[j]])
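The same loop can be run on toy data to confirm that the computed order reproduces the out-of-place result, while a naive left-to-right order does not. The filter functions, channel values and the apply_order helper are made up for illustration.

```python
def apply_order(image, filters, image_idx, filter_idx, order):
    """Overwrite `image` in place, writing output positions in `order`."""
    for j in order:
        image[j] = filters[filter_idx[j]](image[image_idx[j]])
    return image


# Toy data (made up): channels are ints, filters are simple callables.
filters = [lambda c: c + 1, lambda c: c * 10]
image_idx = [4, 0, 3, 0, 4]
filter_idx = [0, 1, 0, 0, 1]
original = [1, 2, 3, 4, 5]

# Out-of-place reference: every output reads the untouched input.
expected = [filters[filter_idx[j]](original[image_idx[j]]) for j in range(5)]

safe = apply_order(list(original), filters, image_idx, filter_idx, [1, 2, 3, 0, 4])
naive = apply_order(list(original), filters, image_idx, filter_idx, [0, 1, 2, 3, 4])
print(safe == expected)   # True
print(naive == expected)  # False: output 1 reads channel 0 after it was overwritten
```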

I am starting to suspect that, when it fails to find a sequence, it is because image_idx defines a closed-loop permutation (a cycle). For instance:

In [29]: in_place_sequence([2, 0, 3, 1])

returns None, but it could still be done in place with extra storage for just one channel:

img[0] -> temp
img[2] -> img[0]
img[3] -> img[2]
img[1] -> img[3]
temp   -> img[1]
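The five steps above generalize to any single cycle: save one channel, walk around the cycle writing each position from the channel it reads, and finish with the saved value. A sketch, assuming the same image/filters/image_idx/filter_idx layout as above; rotate_cycle is a hypothetical name, and it refuses to run if start is not on a cycle.

```python
def rotate_cycle(image, filters, image_idx, filter_idx, start):
    """In place, set image[j] = filters[filter_idx[j]](old image[image_idx[j]])
    for every j on the cycle through `start`, using one temporary."""
    # Check first that following image_idx from start comes back to start.
    j = image_idx[start]
    for _ in range(len(image_idx)):
        if j == start:
            break
        j = image_idx[j]
    else:
        raise ValueError("start is not on a cycle of image_idx")

    temp = image[start]                  # the only value that would otherwise be lost
    j = start
    while image_idx[j] != start:
        src = image_idx[j]
        image[j] = filters[filter_idx[j]](image[src])  # src has not been overwritten yet
        j = src
    image[j] = filters[filter_idx[j]](temp)  # the last position reads the saved value
    return image


identity = [lambda c: c]
print(rotate_cycle(["a", "b", "c", "d"], identity, [2, 0, 3, 1], [0] * 4, start=0))
# ['c', 'a', 'd', 'b']
```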

I am having trouble, though, figuring out a way to implement this automatically. I think the way to go would be to keep my current algorithm and, if it fails to exhaust positions, find the closed loops and do something like the above for each of them. I have the impression, though, that I may be reinventing the wheel here. So before diving into coding that, I thought I'd ask: what is the best way of determining the processing order to minimize intermediate storage?
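Finding the closed loops is a standard cycle decomposition. Once the peeling step is stuck, every leftover position reads another leftover position and is read by exactly one of them, so the leftovers form disjoint cycles that can be walked one at a time. A minimal sketch; find_cycles is a hypothetical name, and it assumes it is only given those leftover positions.

```python
def find_cycles(indices, remaining):
    """Split the leftover positions into cycles, where cycle[k] reads cycle[k + 1]
    and the last element reads cycle[0].

    Assumes the positions in `remaining` form a permutation under `indices`.
    """
    remaining = set(remaining)
    cycles = []
    while remaining:
        start = min(remaining)       # any element works; min keeps the output deterministic
        cycle = [start]
        j = indices[start]
        while j != start:            # follow "reads from" links until we loop back
            cycle.append(j)
            j = indices[j]
        remaining.difference_update(cycle)
        cycles.append(cycle)
    return cycles


print(find_cycles([2, 0, 3, 1], range(4)))        # [[0, 2, 3, 1]]
print(find_cycles([1, 0, 2, 4, 3], [0, 1, 3, 4]))  # [[0, 1], [3, 4]]
```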


EDIT: At Sam Mussmann's encouragement, I have gone ahead and worked out how to find the remaining cycles. My code now looks like this:

def in_place_sequence(indices):
    """
    Figures out a processing sequence for in-place computation.

    Parameters
    ----------
    indices : array-like
        For each output position `j`, the input position `indices[j]`
        it is computed from.

    Returns
    -------
    processing_order : list
        The order in which output should be computed to avoid overwriting
        data needed for a later computation.

    cycles : list of lists
        A list of cycles present in `indices`, each of which will require
        one element of intermediate storage to compute in place. Within a
        cycle, `cycle[k]` reads from `cycle[k + 1]`, and `cycle[-1]` reads
        from `cycle[0]`.

    Notes
    -----
    If not doing the operation in place, if `in_` is a sequence of elements
    to process with a function `f`, then `indices` would be used as follows to
    create the output `out`:

        >>> out = []
        >>> for idx in indices:
        ...     out.append(f(in_[idx]))

    so that `out[j] = f(in_[indices[j]])`.

    If the operation is to be done in place, `in_place_sequence` could be used
    as follows:

        >>> sequence, cycles = in_place_sequence(indices)
        >>> for j in sequence:
        ...     in_[j] = f(in_[indices[j]])
        >>> for cycle in cycles:
        ...     temp = in_[cycle[0]]
        ...     for to, from_ in zip(cycle, cycle[1:]):
        ...         in_[to] = f(in_[from_])
        ...     in_[cycle[-1]] = f(temp)
    """
    indices = list(indices)
    positions = set(range(len(indices)))
    processing_order = []
    change = True
    # Peel off every position that no pending output still needs to read.
    while change and positions:
        change = False
        for j in list(positions):
            val = indices[j]
            if (j not in indices) or (indices.count(val) == 1 and val == j):
                indices[j] = None
                positions.remove(j)
                processing_order.append(j)
                change = True
    # Whatever is left forms disjoint cycles; walk each one.
    cycles = []
    while positions:
        idx = positions.pop()
        start = indices.index(idx)   # the single pending output that reads idx
        cycle = [start]
        while idx != start:
            cycle.append(idx)
            idx = indices[idx]
            positions.remove(idx)
        cycles.append(cycle)
    return processing_order, cycles
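For Python 3, here is a self-contained sketch: a port of the function above (without the debug output), a hypothetical apply_in_place helper that follows the usage pattern from the Notes, and a brute-force check against the out-of-place result for every index list up to length 5. The sequence loop writes each position j from indices[j], and each cycle uses one temporary.

```python
from itertools import product


def in_place_sequence(indices):
    """Python 3 port: returns (processing_order, cycles)."""
    indices = list(indices)
    positions = set(range(len(indices)))
    processing_order = []
    change = True
    while change and positions:
        change = False
        for j in sorted(positions):          # sorted only to make the order deterministic
            val = indices[j]
            if j not in indices or (indices.count(val) == 1 and val == j):
                indices[j] = None
                positions.remove(j)
                processing_order.append(j)
                change = True
    cycles = []
    while positions:
        idx = positions.pop()
        start = indices.index(idx)           # the single pending output reading idx
        cycle = [start]
        while idx != start:
            cycle.append(idx)
            idx = indices[idx]
            positions.remove(idx)
        cycles.append(cycle)
    return processing_order, cycles


def apply_in_place(data, f, indices):
    """Overwrite `data` so that data[j] == f(old data[indices[j]])."""
    sequence, cycles = in_place_sequence(indices)
    for j in sequence:
        data[j] = f(data[indices[j]])
    for cycle in cycles:                     # cycle[k] reads cycle[k + 1]; the last reads cycle[0]
        temp = data[cycle[0]]                # one element of extra storage per cycle
        for to, from_ in zip(cycle, cycle[1:]):
            data[to] = f(data[from_])
        data[cycle[-1]] = f(temp)
    return data


print(in_place_sequence([4, 0, 3, 0, 4]))   # ([1, 2, 3, 0, 4], [])

# Brute-force check against the out-of-place result for short index lists.
f = lambda c: 2 * c + 1
for n in range(1, 6):
    for indices in product(range(n), repeat=n):
        expected = [f(i) for i in indices]
        assert apply_in_place(list(range(n)), f, indices) == expected
```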

Solution

  • I think your method is as good as you'll get.

    Think of your indices list as a directed graph, where each channel is a node, and an edge (u, v) means that output channel v depends on input channel u.

    As written, your solution finds a node that has no outbound edges (or only a self-loop), removes this node and its incoming edge, and repeats until it can't remove any more nodes. If there are no nodes left, you're done -- if there are nodes left, you're stuck.

    In our graph representation, being stuck means that there is a cycle. Adding a temporary channel lets us "split" a node and break the cycle.

    At this point, though, we might want to get smart. Is there any node that we could split that would break more than one cycle? The answer, unfortunately, is no. Each node has only one inbound edge because each output channel v can only depend on one input channel. In order for a node to be part of multiple cycles, it (or some other node) would have to have two inbound edges.

    So, we can break each cycle by adding a temporary channel, and adding a temporary channel can only break one cycle.

    Furthermore, when all you have left is cycles, splitting any node will break one of the cycles. So you don't need any fancy heuristics. Just run the code you have now until it's done -- if there are still positions left, add a temporary channel and run your code again.
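One way to sketch that "run, split, run again" loop follows. It is a variant of the question's peeling rather than its exact code: it tracks, for each channel, how many pending outputs still read it (a Kahn-style topological peel), so it never rescans the list, and when it gets stuck it copies one leftover channel into a single temporary slot and points that channel's reader at the temporary instead. plan_with_temp, run_plan and the step tuples are hypothetical names. On [2, 0, 3, 1] it reproduces the five steps from the question, and the number of "save" steps equals the number of cycles, with one temporary slot reused for each.

```python
from collections import deque
from itertools import product


def plan_with_temp(indices):
    """Return steps: ("save", u) copies channel u into the temp slot;
    ("write", v, src) sets channel v from channel src, or from the temp slot if src == "temp"."""
    n = len(indices)
    src = list(indices)                  # what each output reads; may be redirected to "temp"
    readers = [0] * n                    # pending outputs (other than u itself) reading channel u
    for v, u in enumerate(src):
        if u != v:
            readers[u] += 1
    done = [False] * n
    ready = deque(u for u in range(n) if readers[u] == 0)
    steps = []
    while True:
        while ready:                     # peel: write every channel nobody still needs
            v = ready.popleft()
            steps.append(("write", v, src[v]))
            done[v] = True
            u = src[v]
            if u != "temp" and u != v:
                readers[u] -= 1
                if readers[u] == 0:
                    ready.append(u)
        remaining = [u for u in range(n) if not done[u]]
        if not remaining:
            return steps
        # Stuck: only disjoint cycles are left. Split one node with the temp slot.
        u = remaining[0]
        w = next(v for v in remaining if src[v] == u)   # u's single pending reader
        steps.append(("save", u))
        src[w] = "temp"
        readers[u] = 0
        ready.append(u)


def run_plan(steps, data, f):
    """Execute a plan in place on `data`, applying `f` to every value written."""
    temp = None
    for step in steps:
        if step[0] == "save":
            temp = data[step[1]]
        else:
            _, v, s = step
            data[v] = f(temp if s == "temp" else data[s])
    return data


print(plan_with_temp([2, 0, 3, 1]))
# [('save', 0), ('write', 0, 2), ('write', 2, 3), ('write', 3, 1), ('write', 1, 'temp')]

# Brute-force check against the out-of-place result for short index lists.
f = lambda c: 2 * c + 1
for n in range(1, 6):
    for indices in product(range(n), repeat=n):
        expected = [f(i) for i in indices]
        assert run_plan(plan_with_temp(indices), list(range(n)), f) == expected
```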