python, pytorch

How to efficiently implement forward fill in PyTorch


How can I efficiently implement forward-fill logic (inspired by pandas `ffill`, with a value of 0 marking a missing entry) for a tensor shaped NxLxC (batch, sequence, channel)? Because each channel's sequence is independent, this is equivalent to working with a tensor shaped (N*C)xL.
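The equivalence between NxLxC and (N*C)xL can be sketched as a permute plus reshape: move the sequence axis last, merge batch and channel into a single row axis, apply the row-wise fill, then undo both steps. The sizes below are arbitrary placeholders, not values from the question.

```python
import torch

# Hypothetical sizes, for illustration only.
N, L, C = 2, 5, 3
x = torch.randn(N, L, C)

# (N, L, C) -> (N, C, L) -> (N*C, L): row n*C + c holds x[n, :, c].
flat = x.permute(0, 2, 1).reshape(N * C, L)

# ... a row-wise forward fill would be applied to `flat` here ...

# Undo the reshape: (N*C, L) -> (N, C, L) -> (N, L, C).
restored = flat.reshape(N, C, L).permute(0, 2, 1)
```

`reshape` copies only when needed, which it is here because `permute` returns a non-contiguous view.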

The computation should stay within torch operations so that the output remains differentiable.
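Index-based selection satisfies this requirement with respect to the values being selected: in the backward pass, gradients are accumulated at every source position that was read, while the integer indices themselves receive no gradient. A minimal check with `torch.gather` on made-up values, where the index pattern mimics a forward fill:

```python
import torch

values = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)

# Position 1 is read twice, as if position 2 were a missing value being filled.
idx = torch.tensor([[0, 1, 1]])
out = torch.gather(values, 1, idx)  # -> [[1., 2., 2.]]

out.sum().backward()
# values.grad -> [[1., 2., 0.]]: position 1 receives the gradient of both outputs that read it.
```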

I managed to build something with advanced indexing, but it is O(L²) in both memory and number of operations, so it is neither efficient nor GPU-friendly.


Example:

Given the sequence [0,1,2,0,0,3,0,4,0,0,0,5,6,0] in a tensor shaped 1x14, forward filling yields [0,1,2,2,2,3,3,4,4,4,4,5,6,6].

Another example, shaped 2x4: [[0, 1, 0, 3], [1, 2, 0, 3]] should be forward filled into [[0, 1, 1, 3], [1, 2, 2, 3]].
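As a correctness baseline for any vectorized version, a plain-Python loop with these semantics (0 means missing, leading zeros stay 0) reproduces both examples. `forward_fill_reference` is a hypothetical helper name; this is the slow non-vectorized approach, shown only to pin down the expected behavior.

```python
from typing import List


def forward_fill_reference(row: List[float]) -> List[float]:
    """Naive per-row forward fill where 0 marks a missing value (hypothetical helper)."""
    out = []
    last = 0  # leading zeros stay 0 because nothing precedes them
    for v in row:
        if v != 0:
            last = v  # remember the most recent observed value
        out.append(last)
    return out


print(forward_fill_reference([0, 1, 2, 0, 0, 3, 0, 4, 0, 0, 0, 5, 6, 0]))
# -> [0, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 5, 6, 6]
print([forward_fill_reference(r) for r in [[0, 1, 0, 3], [1, 2, 0, 3]]])
# -> [[0, 1, 1, 3], [1, 2, 2, 3]]
```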


Current method:

We currently use the following code, which is highly unoptimized but still faster than non-vectorized loops:

import torch

def last_zero_sequence_start_indices(t: torch.Tensor) -> torch.Tensor:
    """
    Given a 3D tensor `t`, this function returns a three-dimensional tensor where each entry holds
    the starting index of the last contiguous sequence of zeros up to and including the current index.
    If no sequence of zeros starts at or before the current position, the value is the last index
    (Time - 1), because `argmax` over an all-zero row returns 0.

    In essence, for each position in `t`, the function pinpoints the beginning of the last contiguous
    sequence of zeros up to that position.

    Args:
    - t (torch.Tensor): Input tensor with shape [Batch, Channel, Time].

    Returns:
    - torch.Tensor: Three-dimensional tensor with shape [Batch, Channel, Time] indicating the starting position of
        the last sequence of zeros up to each index in `t`.
    """

    # Create a mask indicating the start of each zero sequence
    start_of_zero_sequence = (t == 0) & torch.cat([
        torch.full(t.shape[:-1] + (1,), True, device=t.device),
        t[..., :-1] != 0,
    ], dim=2)

    # Duplicate this mask into a TxT matrix
    duplicated_mask = start_of_zero_sequence.unsqueeze(2).repeat(1, 1, t.size(-1), 1)

    # Extract the lower triangular part of this matrix (including the diagonal)
    lower_triangular = torch.tril(duplicated_mask)

    # For each row, identify the index of the rightmost '1' (start of the last zero sequence up to that row)
    indices = t.size(-1) - 1 - lower_triangular.int().flip(dims=[3]).argmax(dim=3)

    return indices
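To put the quadratic cost in numbers: the code above materializes tensors of shape [Batch, Channel, Time, Time], and `repeat`, `tril`, `.int()` and `flip` each allocate one. The sketch below computes the size of a single bool tensor and a single int32 tensor of that shape for hypothetical sizes (B=32, C=64, T=1000, chosen only for illustration).

```python
import torch

# Hypothetical sizes, for illustration only.
B, C, T = 32, 64, 1_000

# Number of elements in one [B, C, T, T] tensor (e.g. `duplicated_mask`).
n_elements = B * C * T * T

# Bytes for one bool tensor (repeat, tril) and one int32 tensor (.int(), flip).
bool_bytes = n_elements * torch.empty(0, dtype=torch.bool).element_size()
int_bytes = n_elements * torch.empty(0, dtype=torch.int32).element_size()

print(f"bool: {bool_bytes / 1e9:.1f} GB, int32: {int_bytes / 1e9:.1f} GB")
# -> bool: 2.0 GB, int32: 8.2 GB
```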

Solution

  • Here is an approach that avoids creating the TxT matrix:

    import torch

    def forward_fill(t: torch.Tensor) -> torch.Tensor:
        n_dim, t_dim = t.shape
        # Generate the index range on the same device as the input
        rng = torch.arange(t_dim, device=t.device)

        rng_2d = rng.unsqueeze(0).repeat(n_dim, 1)
        # Replace indices with zero for elements that equal zero
        rng_2d[t == 0] = 0

        # Forward fill the index range so every zero element points to the previous non-zero index.
        # Leading zeros map to index 0, which is itself zero, so they stay 0.
        idx = rng_2d.cummax(1).values
        t = t[torch.arange(n_dim, device=t.device)[:, None], idx]
        return t
    

    Note that this solution is written for 2D input (one sequence per row), but it can easily be modified for more dimensions.
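One possible extension to the question's NxLxC layout applies the same cummax-over-indices idea directly along the sequence dimension, so no reshape is needed. This is a sketch, not part of the original answer: `forward_fill_nlc` is a hypothetical name, and it uses `torch.gather` in place of the answer's advanced indexing (equivalent here, and differentiable with respect to `x`).

```python
import torch


def forward_fill_nlc(x: torch.Tensor) -> torch.Tensor:
    """Forward-fill zeros along dim 1 of an (N, L, C) tensor (hypothetical helper).

    Zeros are treated as missing; leading zeros stay 0.
    """
    n, l, c = x.shape
    # Position index along the sequence axis, broadcast to (N, L, C).
    pos = torch.arange(l, device=x.device).view(1, l, 1).expand(n, l, c)
    # Keep positions of observed (non-zero) entries, set missing ones to 0.
    pos = torch.where(x != 0, pos, torch.zeros_like(pos))
    # Running maximum = index of the most recent non-zero entry (0 if none yet).
    idx = pos.cummax(dim=1).values
    # Read the values at those indices; gradients flow back to x.
    return torch.gather(x, 1, idx)


# Usage on the question's first example, viewed as N=1, L=14, C=1.
x = torch.tensor([0, 1, 2, 0, 0, 3, 0, 4, 0, 0, 0, 5, 6, 0], dtype=torch.float32, requires_grad=True)
filled = forward_fill_nlc(x.view(1, 14, 1))
filled.sum().backward()
# filled.view(-1) -> [0, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 5, 6, 6]
# x.grad counts how many outputs read each position.
```

Memory and work are O(N·L·C) instead of O(N·C·L²), since only one index tensor of the input's shape is created.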