I want to implement the backward function of conv2d.
Here is an example of a linear function:
# Inherit from Function
class LinearFunction(Function):

    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is the gradient of the loss w.r.t. this function's
        # output; backward must return one gradient per forward argument.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias


class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super(Linear, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter('bias', None)

        # Not a very smart way to initialize weights
        nn.init.uniform_(self.weight, -0.1, 0.1)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -0.1, 0.1)

    def forward(self, input):
        # See the autograd section for explanation of what happens here.
        return LinearFunction.apply(input, self.weight, self.bias)
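For what it's worth, the standard way to check a custom Function like this is torch.autograd.gradcheck, which compares the hand-written backward against numerically computed gradients. A minimal sketch (the sizes here are arbitrary, and gradcheck wants double precision):

x = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
w = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
b = torch.randn(5, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(LinearFunction.apply, (x, w, b)))  # True if backward is correct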
I don't think I have a clear understanding of this function yet.
How can I implement the conv2d backward function?
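The key idea is that backward receives grad_output, the gradient of the loss with respect to the forward output, and must return one gradient per forward argument. In the linear example, where output = input.mm(weight.t()) + bias, the chain rule gives grad_input = grad_output.mm(weight), grad_weight = grad_output.t().mm(input), and grad_bias = grad_output.sum(0). conv2d follows the same pattern: the gradient with respect to the input is a transposed convolution of grad_output with the weight, and the gradient with respect to the weight is a correlation of the input with grad_output. You don't have to write either by hand, since PyTorch exposes them as torch.nn.grad.conv2d_input and torch.nn.grad.conv2d_weight.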
This is an implementation of the conv2d backward function:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function


class Conv2dFunction(Function):

    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        output = F.conv2d(input, weight, bias, stride, padding, dilation, groups)
        # Tensors go through save_for_backward; the non-tensor
        # hyperparameters are simply stashed on ctx.
        ctx.save_for_backward(input, weight, bias)
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        stride = ctx.stride
        padding = ctx.padding
        dilation = ctx.dilation
        groups = ctx.groups
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(
                input.shape, weight, grad_output, stride, padding, dilation, groups)
        if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(
                input, weight.shape, grad_output, stride, padding, dilation, groups)
        if bias is not None and ctx.needs_input_grad[2]:
            # Sum over the batch and spatial dimensions, leaving one
            # value per output channel.
            grad_bias = grad_output.sum(dim=(0, 2, 3))
        # One gradient per forward argument; the non-tensor arguments
        # (stride, padding, dilation, groups) get None.
        return grad_input, grad_weight, grad_bias, None, None, None, None
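To use it the same way as the Linear module in your question, wrap it in an nn.Module and verify the backward with gradcheck. A minimal sketch, with arbitrary layer hyperparameters and test sizes (square kernels only, for brevity):

class Conv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d, self).__init__()
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.weight = nn.Parameter(torch.empty(
            out_channels, in_channels // groups, kernel_size, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)
        # Same naive initialization as the Linear example
        nn.init.uniform_(self.weight, -0.1, 0.1)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -0.1, 0.1)

    def forward(self, input):
        return Conv2dFunction.apply(input, self.weight, self.bias,
                                    self.stride, self.padding,
                                    self.dilation, self.groups)

# gradcheck perturbs only the tensors; the int hyperparameters pass through.
x = torch.randn(2, 3, 8, 8, dtype=torch.double, requires_grad=True)
w = torch.randn(4, 3, 3, 3, dtype=torch.double, requires_grad=True)
b = torch.randn(4, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(Conv2dFunction.apply, (x, w, b, 1, 1, 1, 1)))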