def AdaIN(x):
    """Adaptive instance normalization, written to be wrapped in a Keras ``Lambda`` layer.

    ``Lambda`` hands the function a single argument, so all inputs arrive packed in
    one list: ``x[0]`` is the feature map, ``x[1]`` the scale (gamma) and ``x[2]``
    the shift (beta).
    """
    content = x[0]
    # Statistics over axes 1 and 2 (presumably H and W of a channels-last
    # (batch, H, W, C) tensor -- confirm against the data format in use).
    mu = K.mean(content, axis = [1, 2], keepdims = True)
    sigma = K.std(content, axis = [1, 2], keepdims = True) + 1e-7
    normalized = (content - mu) / sigma

    # gamma/beta come out of Dense(filters) as (batch, channels); reshape them to
    # (batch, 1, 1, channels) so they broadcast over axes 1 and 2.
    target_shape = [-1, 1, 1, normalized.shape[-1]]
    gamma = K.reshape(x[1], target_shape)
    beta = K.reshape(x[2], target_shape)
    return normalized * gamma + beta
def g_block(input_tensor, latent_vector, filters):
    """Generator block: upsample -> 3x3 conv -> AdaIN (styled by the latent) -> ReLU.

    Layers are created in a fixed order (gamma, beta, upsample, conv, Lambda,
    activation) so Keras' automatic layer names are stable.
    """
    # Per-channel style parameters predicted from the latent vector. gamma's bias
    # is initialised to ones so the learned scale is offset toward 1 rather than 0.
    gamma = Dense(filters, bias_initializer = 'ones')(latent_vector)
    beta = Dense(filters)(latent_vector)

    upsampled = UpSampling2D()(input_tensor)
    features = Conv2D(filters, 3, padding = 'same')(upsampled)
    # Lambda passes the list as ONE argument; AdaIN unpacks it by index.
    styled = Lambda(AdaIN)([features, gamma, beta])
    activated = Activation('relu')(styled)
    return activated
Please see the code above. I am currently studying StyleGAN and am trying to convert this code to PyTorch, but I can't understand what Lambda does in g_block. According to its definition, AdaIN takes only one argument, yet somehow gamma and beta are also passed to it as inputs. Could someone explain what Lambda does in this code?
Thank you very much.
Lambda layers in Keras are used to call custom functions inside a model. In g_block, Lambda calls the AdaIN function and passes out, gamma and beta to it packed together in a single list. That is why AdaIN declares only one parameter: it receives the three tensors wrapped in one list, x, and accesses them inside the function by indexing that list (x[0], x[1], x[2]).
Here's the PyTorch equivalent:
import torch
import torch.nn as nn
import torch.nn.functional as F
class AdaIN(nn.Module):
    """Adaptive instance normalization (PyTorch port of the Keras ``AdaIN`` function).

    Each channel of each sample is normalised to zero mean and unit standard
    deviation over its spatial positions, then scaled by ``gamma`` and shifted by
    ``beta``. Tensors are channels-first (``(batch, channels, *spatial)``), unlike
    the channels-last Keras original, so the statistics are taken over dims 2+.

    Args:
        eps: added to the standard deviation to avoid division by zero
            (1e-7 matches the Keras original).
    """

    def __init__(self, eps=1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, out, gamma, beta):
        """Normalise ``out`` per channel, then apply ``gamma`` and ``beta``.

        Args:
            out: feature map of shape ``(batch, channels, *spatial)`` with at
                least one spatial dimension (typically ``(B, C, H, W)``).
            gamma: per-sample, per-channel scale of shape ``(batch, channels)``.
            beta: per-sample, per-channel shift of shape ``(batch, channels)``.

        Returns:
            Tensor with the same shape as ``out``.

        Raises:
            ValueError: if ``out`` has no spatial dimension.
        """
        if out.dim() < 3:
            raise ValueError(
                f"AdaIN expects (batch, channels, *spatial), got shape {tuple(out.shape)}"
            )
        spatial_dims = tuple(range(2, out.dim()))
        mean = out.mean(dim=spatial_dims, keepdim=True)
        centered = out - mean
        # Population (biased, divide-by-N) std, matching Keras K.std. Tensor.std
        # defaults to the unbiased N-1 estimator, which differs from the original
        # and is NaN when there is only one spatial position.
        std = centered.pow(2).mean(dim=spatial_dims, keepdim=True).sqrt() + self.eps
        y = centered / std
        # (batch, channels) -> (batch, channels, 1, ..., 1) so the style
        # parameters broadcast over every spatial position.
        broadcast = (1,) * len(spatial_dims)
        scale = gamma.reshape(tuple(gamma.shape) + broadcast)
        bias = beta.reshape(tuple(beta.shape) + broadcast)
        return y * scale + bias
class g_block(nn.Module):
    """PyTorch port of the Keras ``g_block``: upsample x2 -> 3x3 conv -> AdaIN -> ReLU.

    Args:
        filters: number of output channels; also the size of gamma and beta.
        latent_vector_shape: size of the last dimension of ``latent_vector``.
        input_tensor_channels: number of channels of ``input_tensor``.
    """

    def __init__(self, filters, latent_vector_shape, input_tensor_channels):
        super().__init__()
        self.gamma = nn.Linear(in_features=latent_vector_shape, out_features=filters)
        self.beta = nn.Linear(in_features=latent_vector_shape, out_features=filters)
        # 3x3 kernel, stride 1, padding 1 keeps H and W unchanged, i.e. Keras padding='same'.
        self.conv = nn.Conv2d(input_tensor_channels, filters, 3, 1, padding=1)
        self.adain = AdaIN()
        # Mirror the Keras defaults (Dense/Conv2D: glorot_uniform kernel, zero bias).
        # PyTorch's own defaults differ, and leaving beta's bias random made the
        # port start from different parameters than the original.
        for layer in (self.gamma, self.beta, self.conv):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)
        # Keras: Dense(filters, bias_initializer='ones') for gamma.
        nn.init.ones_(self.gamma.bias)

    def forward(self, input_tensor, latent_vector):
        """Apply the block.

        Args:
            input_tensor: ``(batch, input_tensor_channels, H, W)`` feature map.
            latent_vector: ``(batch, latent_vector_shape)`` style input.

        Returns:
            ``(batch, filters, 2*H, 2*W)`` tensor.
        """
        gamma = self.gamma(latent_vector)
        beta = self.beta(latent_vector)
        # Keras UpSampling2D defaults to size=(2, 2) with nearest-neighbour interpolation.
        out = F.interpolate(input_tensor, scale_factor=2, mode='nearest')
        out = self.conv(out)
        out = self.adain(out, gamma, beta)
        return torch.relu(out)
# Sample usage: a 3-channel 10x10 input styled by a 5-dimensional latent vector.
# Guarded so that importing this module does not run the demo.
if __name__ == "__main__":
    input_tensor = torch.randn((1, 3, 10, 10))
    latent_vector = torch.randn((1, 5))
    g = g_block(3, latent_vector.shape[1], input_tensor.shape[1])
    out = g(input_tensor, latent_vector)
    # Upsampling doubles H and W and padding=1 preserves them: expect (1, 3, 20, 20).
    print(out.shape)
    print(out)
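Note: unlike Keras layers, PyTorch layers need their input sizes when they are created, so you must pass the latent vector size and the number of input channels when constructing g_block.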