While writing a class that inherits from torch.autograd.Function, I found that in the backward function the data printed in Python ([[[1., 1., 1., 1., 1., 1., 1., 1.]]]) and the data fetched inside the underlying CUDA code ([[[1., 0., 0., 0., 0., 0., 0., 0.]]]) are not the same unless I clone output_grad. I wonder what the problem might be. Thanks.
class FusedRotaryEmbeddingFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, cos, sin, position_ids, tensor_index, k_size, rotary_size, base):
        ctx.save_for_backward(cos, sin, position_ids)
        ctx.tensor_index = tensor_index
        ctx.k_size = k_size
        ctx.rotary_size = rotary_size
        ctx.base = base
        return fused_apply_rotary_emb_cuda.forward(x, cos, sin, position_ids, tensor_index, k_size, rotary_size, base)

    @staticmethod
    def backward(ctx, output_grad):
        cos, sin, position_ids = ctx.saved_tensors
        tensor_index = ctx.tensor_index
        k_size = ctx.k_size
        rotary_size = ctx.rotary_size
        base = ctx.base
        # incorrect result, unless I pass output_grad.clone()
        x_grad = fused_apply_rotary_emb_cuda.backward(output_grad, cos, sin, position_ids, tensor_index, k_size, rotary_size, base)
        return (x_grad, None, None, None, None, None, None, None)
Actually, I have found the cause: PyTorch tensors can have different underlying data layouts, and the output_grad produced by the downstream kernel is in fact non-contiguous. My custom CUDA kernel accesses the data assuming a contiguous layout, which is why it reads unexpected values.
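Here is a minimal standalone sketch of the mismatch (it does not use the actual fused_apply_rotary_emb_cuda extension; the transpose is just a stand-in for whatever downstream op produced the non-contiguous view). A kernel that indexes the raw data pointer with contiguous strides effectively sees the values in storage order, not in the logical order that print() shows:

import torch

a = torch.arange(6.).reshape(2, 3)
grad = a.t()                       # logical shape (3, 2), but storage is still laid out for (2, 3)

print(grad)                        # tensor([[0., 3.], [1., 4.], [2., 5.]])
print(grad.is_contiguous())        # False
print(grad.stride())               # (1, 3), not the contiguous (2, 1)

# What a kernel that assumes contiguous strides would effectively read:
seen_by_kernel = torch.as_strided(grad, grad.size(), (2, 1))
print(seen_by_kernel)              # tensor([[0., 1.], [2., 3.], [4., 5.]]) -- different values

# Copying into the expected layout fixes the access pattern:
print(grad.contiguous().is_contiguous())   # True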
So, to avoid an access pattern that is incompatible with the tensor's layout, I should convert output_grad to a contiguous layout like this:
@staticmethod
def backward(ctx, output_grad):
    cos, sin, position_ids = ctx.saved_tensors
    tensor_index = ctx.tensor_index
    k_size = ctx.k_size
    rotary_size = ctx.rotary_size
    base = ctx.base
    x_grad = fused_apply_rotary_emb_cuda.backward(output_grad.contiguous(), cos, sin, position_ids, tensor_index, k_size, rotary_size, base)
    return (x_grad, None, None, None, None, None, None, None)
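A nice property of this fix is that .contiguous() only copies when it has to: if the incoming gradient is already contiguous, it returns the tensor itself, so calling it unconditionally is cheap. A quick check:

import torch

x = torch.randn(2, 3)
print(x.contiguous() is x)     # True: already contiguous, no copy is made

y = x.t()                      # non-contiguous view
print(y.is_contiguous())       # False
print(y.contiguous() is y)     # False: a contiguous copy is materialized only here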
I guess .clone() also works because, in this case, it produces a contiguous tensor and copies the data into it.
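For what it's worth, here is a small check under the assumption (mine, for illustration; it would be consistent with the values printed above) that the incoming gradient is a broadcast/expanded view that only stores a single element. In that case both .clone() and .contiguous() materialize a dense copy:

import torch

# Assumed shape for illustration: a gradient expanded along the last dimension.
g = torch.ones(1, 1, 1).expand(1, 1, 8)
print(g.stride())                        # (1, 1, 0): stride 0, only one element is stored
print(g.is_contiguous())                 # False
print(g.clone().is_contiguous())         # True: overlapping strides cannot be preserved
print(g.contiguous().is_contiguous())    # True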