Lately I have been developing a function that deals with tensors of shape:
torch.Size([51, 265, 23, 23])
where the first dimension is time, the second is the pattern index, and the last two are the pattern size.
Each individual pattern can contain at most 3 states: -1, 0 and 1. A pattern is considered 'alive' if it contains all 3 states, and 'dead' in every other case.
My objective is to filter out all the dead patterns by checking only the last time step of the tensor.
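To make the alive/dead definition concrete, here is a tiny hypothetical example with 2 time steps and 3 patterns of size 2x2, in place of the real (51, 265, 23, 23) tensor. It applies the same per-pattern `unique()` check that the loop-based function uses, looking only at the last time step.

```python
import torch

# Hypothetical toy data: (time=2, pattern=3, H=2, W=2)
sims = torch.tensor([
    # t = 0 (ignored by the filter)
    [[[0, 0], [0, 0]],
     [[0, 0], [0, 0]],
     [[0, 0], [0, 0]]],
    # t = -1 (last time step, the one that is checked)
    [[[-1, 0], [1, 1]],   # pattern 0: contains -1, 0 and 1 -> alive
     [[0, 0], [1, 1]],    # pattern 1: no -1 -> dead
     [[-1, 1], [0, 0]]],  # pattern 2: contains all three -> alive
])

# Number of distinct states per pattern at the last time step
n_states = [sims[-1, i].unique().numel() for i in range(sims.shape[1])]
alive = [n == 3 for n in n_states]
print(n_states)  # [3, 2, 3]
print(alive)     # [True, False, True]
```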
import torch

def filter_patterns(tensor_sims):
    # Get the indices of the patterns (dim 1) that need to be kept
    keep_indices = torch.tensor([i for i in range(tensor_sims.shape[1])
                                 if tensor_sims[-1, i].unique().numel() == 3],
                                dtype=torch.long)  # long dtype so an empty list still indexes
    # Keep only the patterns that meet the condition
    tensor_sims = tensor_sims[:, keep_indices]
    print(f'Number of patterns: {tensor_sims.shape[1]}')
    return tensor_sims
Unfortunately, I'm not able to get rid of the for loop.
I tried playing around with torch.unique() and its dim parameter, and I tried reshaping and flattening the tensor, but nothing worked.
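Not from the original post: `torch.unique` returns the distinct values (or distinct slices, with `dim`) of the whole tensor, not a count per pattern. One way to get a per-pattern count of distinct values without a loop is to sort each flattened pattern and count the positions where neighbouring values differ. This is a minimal sketch; the function names are hypothetical. Unlike the mask approach, it reproduces the loop's exact condition (`unique().numel() == 3`) for arbitrary values.

```python
import torch

def count_distinct_per_pattern(last_step):
    """Number of distinct values in each pattern of a (patterns, H, W) tensor."""
    flat = last_step.flatten(1)                # (patterns, H*W)
    sorted_vals = flat.sort(dim=1).values      # equal values become adjacent
    # A new distinct value starts wherever two neighbours differ
    changes = (sorted_vals[:, 1:] != sorted_vals[:, :-1]).sum(dim=1)
    return changes + 1                         # +1 for the first value of each row

def filter_patterns_sorted(tensor_sims):
    # Same condition as the loop: exactly 3 distinct values at the last time step
    keep = count_distinct_per_pattern(tensor_sims[-1]) == 3
    return tensor_sims[:, keep]
```

Sorting costs O(n log n) per pattern, so for values known to be in {-1, 0, 1} the three-mask approach is simpler; the sort-based version is mainly useful when the set of possible values is not fixed.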
import torch

def filter_patterns(tensor_sims):
    # Flatten the spatial dimensions of the last time step: (patterns, H*W)
    x_ = tensor_sims[-1].flatten(1)
    # Create masks telling whether each pattern contains -1, 0 and 1
    mask_minus_one = (x_ == -1).any(dim=1)
    mask_zero = (x_ == 0).any(dim=1)
    mask_one = (x_ == 1).any(dim=1)
    # Combine the masks using logical_and: all three states must be present
    mask = mask_minus_one.logical_and(mask_zero).logical_and(mask_one)
    # Keep only the patterns that meet the condition
    tensor_sims = tensor_sims[:, mask]
    print(f'Number of patterns: {tensor_sims.shape[1]}')
    return tensor_sims
The new implementation is much faster.
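Before relying on the faster version, it can be worth checking that it selects the same patterns as the loop. The sketch below is not from the post: the function names are hypothetical, and the random data is smaller than the real tensor. It uses 2x2 patterns so that dead patterns actually occur. Note that the two conditions agree only because every value is in {-1, 0, 1}; with other values present, "3 distinct values" and "contains -1, 0 and 1" can differ.

```python
import torch

def filter_loop(tensor_sims):
    # Loop-based condition: exactly 3 distinct values at the last time step
    keep = torch.tensor([i for i in range(tensor_sims.shape[1])
                         if tensor_sims[-1, i].unique().numel() == 3],
                        dtype=torch.long)
    return tensor_sims[:, keep]

def filter_masks(tensor_sims):
    # Vectorized condition: -1, 0 and 1 each appear at least once
    x_ = tensor_sims[-1].flatten(1)
    mask = (x_ == -1).any(dim=1) & (x_ == 0).any(dim=1) & (x_ == 1).any(dim=1)
    return tensor_sims[:, mask]

torch.manual_seed(0)
# Hypothetical random data restricted to {-1, 0, 1}
sims = torch.randint(-1, 2, (5, 200, 2, 2))
assert torch.equal(filter_loop(sims), filter_masks(sims))
print(filter_masks(sims).shape[1], "patterns kept")
```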
I don't believe you can get away with torch.unique, because it won't work per column. Instead of iterating over dim=1, you could construct three mask tensors that check for the -1, 0, and 1 values, respectively. To compute the resulting column mask, some basic logic is enough when combining the masks:
Since you only check the last time step, focus on that and flatten the spatial dimensions (here x is your tensor_sims):
x_ = x[-1].flatten(1)
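A quick shape check of that step, using hypothetical random data with the shape from the question: `x[-1]` drops the time dimension, and `flatten(1)` merges everything from dim 1 onward, so each of the 265 patterns becomes one row of 23 * 23 = 529 values.

```python
import torch

# Hypothetical random data with the question's shape: (time, pattern, H, W)
x = torch.randint(-1, 2, (51, 265, 23, 23), dtype=torch.int8)

last = x[-1]           # last time step: (265, 23, 23)
x_ = last.flatten(1)   # merge dims 1..end: (265, 529)

print(last.shape)  # torch.Size([265, 23, 23])
print(x_.shape)    # torch.Size([265, 529])
```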
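The three masks that identify the -1, 0, and 1 conditions can be obtained with x_ == -1, x_ == 0, and x_ == 1, respectively. Reduce each one with any(dim=1) to find out whether a pattern contains that value at least once, then combine the results with torch.logical_and, since an alive pattern must contain all three:
keep_mask = (x_ == -1).any(dim=1).logical_and((x_ == 0).any(dim=1)).logical_and((x_ == 1).any(dim=1))
Finally, use this boolean mask to keep the alive patterns across all time steps:
x = x[:, keep_mask]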
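A possible generalization, not part of the original answer: the three comparisons can be folded into a single broadcast comparison against a tensor of required states, which also turns the set of states into a parameter. The function name and the `states` argument are hypothetical. For `states=(-1, 0, 1)` it selects the same patterns as the three-mask version.

```python
import torch

def filter_patterns_states(tensor_sims, states=(-1, 0, 1)):
    """Keep patterns whose last time step contains every value in `states`."""
    x_ = tensor_sims[-1].flatten(1)                              # (patterns, H*W)
    s = torch.tensor(states, dtype=x_.dtype, device=x_.device)   # (k,)
    # (patterns, H*W, 1) == (k,) broadcasts to (patterns, H*W, k)
    present = (x_.unsqueeze(-1) == s).any(dim=1)                 # (patterns, k)
    mask = present.all(dim=1)                                    # every state present
    return tensor_sims[:, mask]
```

The broadcast builds a temporary boolean tensor k times the size of `x_`; for three states and 265 x 529 values this is small, but for many states the separate-mask version uses less memory.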