trax: tl.Relu and tl.ShiftRight layers are nested inside the Serial combinator


I am trying to build an attention model, but by default the Relu and ShiftRight layers end up nested inside a Serial combinator, which then causes errors during training.

import numpy as np
from trax import layers as tl
from trax import shapes

layer_block = tl.Serial(
    tl.Relu(),
    tl.LayerNorm(),
)

x = np.array([[-2, -1, 0, 1, 2],
              [-20, -10, 0, 10, 20]]).astype(np.float32)

layer_block.init(shapes.signature(x))
y = layer_block(x)

print(f'layer_block: {layer_block}')

Output

layer_block: Serial[
  Serial[
    Relu
  ]
  LayerNorm
]

Expected Output

layer_block: Serial[
  Relu
  LayerNorm
]

The same problem arises with tl.ShiftRight(), as shown below.
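For example, a minimal block built with tl.ShiftRight() (reusing the imports and the input x from above; the exact printed names may vary by Trax version) shows the same extra Serial wrapper:

shift_block = tl.Serial(
    tl.ShiftRight(),
    tl.LayerNorm(),
)
shift_block.init(shapes.signature(x))
print(f'shift_block: {shift_block}')

# Prints a nested structure analogous to the Relu case, e.g.:
# shift_block: Serial[
#   Serial[
#     ShiftRight(1)
#   ]
#   LayerNorm
# ]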

The code above is taken from Example 5 in the official documentation.

Thanks in advance


Solution

  • I could not find an exact solution to this problem, but you can work around it by creating custom functions with tl.Fn() and putting the Relu and ShiftRight code inside them.

    import numpy as np
    from trax import layers as tl
    from trax import shapes
    from trax.fastmath import numpy as jnp


    def _zero_pad(x, pad, axis):
        """Helper for jnp.pad with 0s for the single-axis case."""
        pad_widths = [(0, 0)] * len(x.shape)
        pad_widths[axis] = pad  # Padding on the chosen axis.
        return jnp.pad(x, pad_widths, mode='constant')


    # Set the ShiftRight parameters as globals.
    n_positions = 1
    mode = 'train'


    def f(x):
        """Shift the sequence right by padding zeros at the front of axis 1."""
        if mode == 'predict':
            return x
        padded = _zero_pad(x, (n_positions, 0), 1)
        return padded[:, :-n_positions]


    layer_block = tl.Serial(
        tl.Fn('Relu', lambda x: jnp.where(x <= 0, jnp.zeros_like(x), x)),
        tl.LayerNorm(),
        tl.Fn(f'ShiftRight({n_positions})', f),
    )

    x = np.array([[-2, -1, 0, 1, 2],
                  [-20, -10, 0, 10, 20]]).astype(np.float32)
    layer_block.init(shapes.signature(x))
    y = layer_block(x)

    print(f'layer_block: {layer_block}')
    

    Output

    layer_block: Serial[
      Relu
      LayerNorm
      ShiftRight(1)
    ]
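
  • As a design note: instead of relying on n_positions and mode as globals, you can wrap the same tl.Fn construction in a small factory function that closes over its parameters, which mirrors how Trax layer constructors are typically written. A minimal sketch under that assumption (MyShiftRight is a hypothetical name, not part of the Trax API), reusing _zero_pad, tl, and jnp from above:

    def MyShiftRight(n_positions=1, mode='train'):
        """Hypothetical factory: builds the ShiftRight Fn layer via a closure."""
        def f(x):
            if mode == 'predict':  # No shifting during autoregressive decoding.
                return x
            padded = _zero_pad(x, (n_positions, 0), 1)
            return padded[:, :-n_positions]
        return tl.Fn(f'ShiftRight({n_positions})', f)

    layer_block = tl.Serial(
        tl.Fn('Relu', lambda x: jnp.where(x <= 0, jnp.zeros_like(x), x)),
        tl.LayerNorm(),
        MyShiftRight(n_positions=1, mode='train'),
    )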