Tags: neural-network, deep-learning, bigdata, deep-residual-networks, generic-derivation

Clarification on NN residual layer back-prop derivation


I've looked everywhere and can't find anything that explains the actual derivation of backprop for residual layers. Here's my best attempt and where I'm stuck. It's worth mentioning that the derivation I'm hoping for is a generic one, not limited to convolutional NNs.

If the formula for calculating the output of a normal hidden layer is F(x), then the formula for a hidden layer with a residual connection is F(x) + o, where x is the weight-adjusted output of a previous layer, o is the output of a previous layer, and F is the activation function. To get the delta for a normal layer during back-propagation one needs to calculate the gradient of the output, ∂F(x)/∂x. For a residual layer this is ∂(F(x) + o)/∂x, which is separable into ∂F(x)/∂x + ∂o/∂x (1).
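For readability, here is the same relationship in display form (purely a restatement of the paragraph above, writing the residual layer's output as y):

```latex
y = F(x) + o,
\qquad
\frac{\partial y}{\partial x}
  = \frac{\partial \bigl( F(x) + o \bigr)}{\partial x}
  = \frac{\partial F(x)}{\partial x} + \frac{\partial o}{\partial x}
\tag{1}
```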

If all of this is correct, how does one deal with ∂o/∂x? It seems to me that it depends on how far back in the network o comes from.

  • If o is just from the previous layer then o*w=x, where w are the weights connecting the previous layer to the layer for F(x). Taking the derivative of each side with respect to o gives ∂(o*w)/∂o = ∂x/∂o, and the result is w = ∂x/∂o, which is just the inverse of the term that comes out at (1) above. Does it make sense that in this case the gradient of the residual layer is just ∂F(x)/∂x + 1/w? Is it accurate to interpret 1/w as a matrix inverse? If so, is that actually being computed by NN frameworks that use residual connections, or is there some shortcut for adding in the error from the residual?

  • If o is from further back in the network then, I think, the derivation becomes slightly more complicated. Here is an example where the residual comes from one layer further back in the network. The architecture is Input--w1--L1--w2--L2--w3--L3--Out, with a residual connection from L1 to L3. The symbol o from the first example is replaced by the layer output L1 to avoid ambiguity. We are trying to calculate the gradient at L3 during back-prop, where the forward function is F(x)+L1 with x=F(L1*w2)*w3. The derivative of this relationship is ∂x/∂L1=∂(F(L1*w2)*w3)/∂L1, which is more complicated but doesn't seem too difficult to evaluate numerically (both cases are written out in display form just after this list).
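Written out cleanly, the two cases above are (same symbols as in the bullets; nothing new is added here, only the notation):

```latex
% Case 1: shortcut from the immediately preceding layer
x = o\,w \;\Rightarrow\; \frac{\partial x}{\partial o} = w
\qquad\text{(is the proposed } \partial o/\partial x = w^{-1} \text{ really a matrix inverse?)}

% Case 2: shortcut from one layer further back (Input--w1--L1--w2--L2--w3--L3--Out)
L_2 = F(L_1 w_2), \quad x = L_2 w_3, \quad \text{output of } L_3 = F(x) + L_1,
\quad
\frac{\partial x}{\partial L_1} = \frac{\partial \bigl( F(L_1 w_2)\, w_3 \bigr)}{\partial L_1}
```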

If the above derivation is reasonable then it's worth noting that there is a case where it fails: when a residual connection originates from the Input layer. This is because the input cannot be broken down into an o*w=x expression (where x would be the input values). I think this would suggest that residual connections cannot originate from the input layer, but since I've seen network architecture diagrams with residual connections that originate from the input, this casts my derivations above into doubt. I can't see where I've gone wrong, though. If anyone can provide a derivation or code sample showing how the gradient at residual merge points is calculated correctly, I would be deeply grateful.

EDIT:

The core of my question is, when using residual layers and doing vanilla back-propagation, is there any special treatment of the error at the layers where residuals are added? Since there is a 'connection' between the layer where the residual comes from and the layer where it is added, does the error need to get distributed backwards over this 'connection'? My thinking is that since residual layers provide raw information from the beginning of the network to deeper layers, the deeper layers should provide raw error to the earlier layers.

Based on what I've seen (reading the first few pages of googleable forums, reading the essential papers, and watching video lectures) and Maxim's post below, I'm starting to think that the answer is that ∂o/∂x = 0 and that we treat o as a constant.

Does anyone do anything special during back-prop through a NN with residual layers? If not, then does that mean residual layers are an 'active' part of the network on only the forward pass?


Solution

  • I think you've over-complicated residual networks a little bit. The original paper is "Deep Residual Learning for Image Recognition" by Kaiming He et al. (arXiv:1512.03385).

    In section 3.2, they describe the "identity" shortcuts as y = F(x, W) + x, where W are the trainable parameters. You can see why it's called "identity": the value from the previous layer is added as-is, without any complex transformation. This does two things:

    • F now learns the residual y - x (discussed in 3.1), in short: it's easier to learn.
    • The network gets an extra connection to the previous layer, which improves gradient flow.

    The backward flow through the identity mapping is trivial: the error message is passed through it unchanged, and no inverse matrices are involved (in fact, they aren't involved in any linear layer either; back-prop multiplies the error by the transpose of the weight matrix, not its inverse).
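    To make "passed unchanged" concrete, here is a minimal NumPy sketch of manual forward/backward passes for a toy fully-connected residual block y = ReLU(x @ W) + x, followed by a finite-difference check. All names, shapes, and the ReLU choice are illustrative assumptions, not code from the paper or any framework:

```python
import numpy as np

def residual_block_forward(x, W):
    """y = ReLU(x @ W) + x  (identity shortcut, toy fully-connected block)."""
    z = x @ W                 # pre-activation of the residual branch
    f = np.maximum(z, 0.0)    # F(x, W): the branch output
    y = f + x                 # identity shortcut: add the input unchanged
    return y, (x, W, z)

def residual_block_backward(dy, cache):
    """Given dL/dy, return dL/dx and dL/dW."""
    x, W, z = cache
    df = dy * (z > 0)         # backprop through ReLU
    dW = x.T @ df             # gradient for the branch weights
    dx_branch = df @ W.T      # gradient through the residual branch
    dx_shortcut = dy          # identity shortcut: gradient passes through unchanged
    return dx_branch + dx_shortcut, dW

# Finite-difference check of dL/dx for the loss L = sum(dy * y)
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4))
W = rng.standard_normal((4, 4))
dy = rng.standard_normal((2, 4))
dx, dW = residual_block_backward(dy, residual_block_forward(x, W)[1])

eps = 1e-6
dx_num = np.zeros_like(x)
for i in np.ndindex(*x.shape):
    xp, xm = x.copy(), x.copy()
    xp[i] += eps
    xm[i] -= eps
    dx_num[i] = np.sum(dy * (residual_block_forward(xp, W)[0] -
                             residual_block_forward(xm, W)[0])) / (2 * eps)
print(np.allclose(dx, dx_num, atol=1e-4))   # True: no matrix inverse anywhere
```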

    Now, the paper's authors go a bit further and consider a slightly more complicated version of F, one that changes the output dimensions (which is probably what you had in mind). They write it generally as y = F(x, W) + Ws * x, where Ws is a projection matrix. Note that, although it's written as a matrix multiplication, in its simplest form this operation is very cheap: it just pads x with extra zeros to make its shape larger. You can read a discussion of this operation in this question. But this changes the backward pass very little: the error message is simply clipped back to the original shape of x.
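    A sketch of that zero-padding shortcut in the same toy NumPy setting as above; the 2-D (batch, features) layout is my assumption, whereas He et al. describe padding along the channel dimension of conv feature maps:

```python
import numpy as np

def pad_shortcut_forward(x, out_dim):
    """Zero-padding shortcut: widen x from shape (N, d) to (N, out_dim)."""
    n, d = x.shape
    s = np.zeros((n, out_dim))
    s[:, :d] = x                  # copy x; the extra columns stay zero
    return s

def pad_shortcut_backward(dy, in_dim):
    """Backward of the zero-padding shortcut: clip to x's original width."""
    return dy[:, :in_dim]         # gradient w.r.t. the padded zeros is discarded

# Usage inside a residual block whose branch widens the representation:
#   y  = branch_forward(x) + pad_shortcut_forward(x, branch_width)
#   dx = branch_backward(dy) + pad_shortcut_backward(dy, x.shape[1])
# For a learned projection shortcut y = F(x, W) + x @ Ws, the shortcut's
# backward contribution would instead be dy @ Ws.T (again, no inverse).
```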