torch, einops

einops not equivalent to torch.chunk?


I'm trying to replicate the following 2 lines in einops:

emb = emb[..., None, None]
cond_w, cond_b = th.chunk(emb, 2, dim=1)
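For reference, th.chunk(emb, 2, dim=1) splits dim 1 into two contiguous halves: the first half of the channels goes to cond_w and the second half to cond_b. The sketch below shows that behavior. It assumes th is torch and uses hypothetical sizes (batch 3, 8 channels) chosen only for illustration.

```python
import torch as th

# Hypothetical sizes for illustration: batch of 3, 8 embedding channels.
b, channels = 3, 8
emb = th.arange(b * channels, dtype=th.float32).reshape(b, channels)

emb = emb[..., None, None]                # (3, 8) -> (3, 8, 1, 1)
cond_w, cond_b = th.chunk(emb, 2, dim=1)  # split dim 1 into 2 chunks

# chunk keeps the batch axis and splits channels into contiguous halves:
# cond_w holds channels 0..3 and cond_b holds channels 4..7 of every sample.
half = channels // 2
assert cond_w.shape == (b, half, 1, 1)
assert th.equal(cond_w, emb[:, :half])
assert th.equal(cond_b, emb[:, half:])
```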

So far, I've managed to get:

emb = rearrange(emb, "b (c h w) -> b c h w", w=1, h=1)
cond_w, cond_b = th.chunk(emb, 2, dim=1)

This works fine. But when I do:

emb = rearrange(emb, "b (c h w) -> b c h w", w=1, h=1)
cond_w, cond_b = rearrange(emb, "b (split c) ... -> b split c ...", split=2)

The output is not the same, even though the shapes are. Does anyone know what's going on here?


Solution

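  • Put the split axis first. Unpacking a tensor (cond_w, cond_b = ...) iterates over its first dimension. With "b split c ..." the two names therefore receive batch elements 0 and 1, which only unpacks at all when the batch size is 2. They do not receive the two halves of the channel axis:

        cond_w, cond_b = rearrange(emb, "b (split c) ... -> split b c ...", split=2)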