
How do I add a Batch Normalization layer to my model in Flux.jl?


I have a simple model that I defined, but I want it to use batch normalization so I don't have to calculate and apply the normalizations manually. The model currently looks like this:

using Flux

m = Chain(
  Dense(28^2, 64),
  Dense(64, 10),
  softmax)

How can I edit this model to add a BatchNorm layer, or define a new one altogether?


Solution

  • Using Flux.jl's built-in BatchNorm layer, you can do the following:

    m = Chain(
      Dense(28^2, 64),
      BatchNorm(64, relu),  # normalize the 64 hidden activations, then apply relu
      Dense(64, 10),
      BatchNorm(10),        # normalize the 10 outputs before softmax
      softmax)
    

    where relu is the element-wise activation applied after normalization. For background on why relu is such a common choice of activation, see https://stats.stackexchange.com/questions/226923/why-do-we-use-relu-in-neural-networks-and-how-do-we-use-it. You can find more details on the BatchNorm layer in the Flux.jl docs.
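
    One thing to keep in mind is that BatchNorm behaves differently during training and inference. The sketch below runs the model on a dummy batch and shows how to switch modes explicitly; the input x and the batch size of 5 are made-up values for illustration, while trainmode! and testmode! are standard Flux utilities.

    using Flux

    # Dummy batch: 5 flattened 28×28 images, one sample per column.
    x = rand(Float32, 28^2, 5)

    # In training mode, BatchNorm normalizes with per-batch statistics
    # (recent Flux versions also switch to this automatically inside
    # gradient calls).
    Flux.trainmode!(m)
    y_train = m(x)

    # In test mode, BatchNorm uses its accumulated running statistics,
    # so outputs are deterministic and independent of the batch.
    Flux.testmode!(m)
    y_test = m(x)

    size(y_test)  # (10, 5): a 10-class probability column per sample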