I want to train a feed-forward neural network with a single hidden layer that models the equations below:
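$$
\begin{aligned}
h &= g(W_1 \cdot \mathrm{input}_1 + V_1 \cdot \mathrm{input}_2 + b) \\
\mathrm{output}_1 &= f(W_2 \cdot h + b_w) \\
\mathrm{output}_2 &= f(V_2 \cdot h + b_v)
\end{aligned}
$$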
where $f$ and $g$ are activation functions, $h$ is the hidden representation, $W_1, W_2, V_1, V_2$ are weight matrices, and $b, b_w, b_v$ are the respective biases.
I can't concatenate the two inputs, because that would result in a single weight matrix. I can't train two separate networks either, because then the latent representation would miss the interaction between the two inputs. Any help is much appreciated. I have also attached the network diagram below.
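I decided to write my own linear layer that computes $W_1 \cdot \mathrm{input}_1 + V_1 \cdot \mathrm{input}_2 + b$, i.e. the argument of $g$ in the hidden-layer equation. The activation $g$ itself is applied afterwards in the model. The layer creates two parameters, $W_1$ and $V_1$, multiplies input1 and input2 by them, and adds the two products and the bias together. The code is given below: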
import torch
import torch.nn as nn
import math
class MyLinearLayer(nn.Module):
    """Two-input affine layer computing ``x1 @ W1.T + x2 @ V1.T + bias``.

    This is the pre-activation of the hidden layer
    ``h = g(W1.input1 + V1.input2 + b)``; the non-linearity ``g`` is applied
    by the caller.

    Args:
        size_in1: Number of features in the first input.
        size_out1: Number of output (hidden) features.
        size_in2: Number of features in the second input. Defaults to
            ``size_in1`` so existing two-argument calls keep working.
    """

    def __init__(self, size_in1, size_out1, size_in2=None):
        super().__init__()
        if size_in2 is None:
            size_in2 = size_in1
        self.size_in1, self.size_out1 = size_in1, size_out1
        self.size_in2 = size_in2
        # torch.empty only allocates storage; reset_parameters() fills it.
        # (The original torch.Tensor(...) calls were never initialised, so
        # training started from whatever bytes happened to be in memory.)
        self.W1 = nn.Parameter(torch.empty(size_out1, size_in1))
        self.V1 = nn.Parameter(torch.empty(size_out1, size_in2))
        self.bias = nn.Parameter(torch.empty(size_out1))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise W1, V1 and bias from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).

        ``fan_in`` is ``size_in1 + size_in2``. This layer is mathematically
        ``[W1 V1] @ [x1; x2] + bias``, so this matches the default
        initialisation of ``nn.Linear(size_in1 + size_in2, size_out1)``.
        """
        fan_in = self.size_in1 + self.size_in2
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0.0
        nn.init.uniform_(self.W1, -bound, bound)
        nn.init.uniform_(self.V1, -bound, bound)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Return ``x[0] @ W1.T + x[1] @ V1.T + bias``.

        Args:
            x: Indexable pair ``(input1, input2)``. This is either a
                tuple/list of two tensors or a tensor stacked along dim 0.
                ``input1`` has ``size_in1`` trailing features and ``input2``
                has ``size_in2``. Any leading batch dimensions broadcast
                through ``@``.

        Returns:
            Tensor with ``size_out1`` trailing features (no activation applied).
        """
        x1, x2 = x[0], x[1]
        return x1 @ self.W1.t() + x2 @ self.V1.t() + self.bias
class NN(nn.Module):
    """Single-hidden-layer network with two inputs and two output heads.

    Implements::

        h       = relu(W1 . input1 + V1 . input2 + b)
        output1 = relu(W2 . h + b_w)
        output2 = relu(V2 . h + b_v)

    Args:
        in_ch: Feature size of each input (``MyLinearLayer`` is built with
            a single input size, so both inputs share it).
        h_ch: Size of the hidden representation ``h``.
        out_ch: Feature size of ``output1`` (and of ``output2`` unless
            ``out_ch2`` is given).
        out_ch2: Optional feature size of ``output2``; defaults to ``out_ch``.
    """

    def __init__(self, in_ch, h_ch, out_ch, out_ch2=None):
        super().__init__()
        if out_ch2 is None:
            out_ch2 = out_ch
        self.input = MyLinearLayer(in_ch, h_ch)
        self.W2 = nn.Linear(h_ch, out_ch)
        self.V2 = nn.Linear(h_ch, out_ch2)
        self.act = nn.ReLU()

    def forward(self, i1, i2):
        """Return ``(output1, output2)`` for inputs ``i1`` and ``i2``."""
        # Hand the inputs over as a tuple. The layer only indexes x[0] and
        # x[1], so torch.stack would just copy both tensors into a new one
        # (and force them to have identical shapes) for no benefit.
        h = self.act(self.input((i1, i2)))
        o1 = self.act(self.W2(h))
        o2 = self.act(self.V2(h))
        return o1, o2
# Build a 5 -> 10 -> (5, 5) network and run one forward pass on a batch of 2.
model = NN(5, 10, 5)
o1, o2 = model(torch.rand(2, 5), torch.rand(2, 5))

# List every trainable parameter together with its shape.
trainable = ((n, p) for n, p in model.named_parameters() if p.requires_grad)
for name, param in trainable:
    print(name, '->', param.shape)
This prints the 7 trainable parameters:
input.W1 -> torch.Size([10, 5])
input.V1 -> torch.Size([10, 5])
input.bias -> torch.Size([10])
W2.weight -> torch.Size([5, 10])
W2.bias -> torch.Size([5])
V2.weight -> torch.Size([5, 10])
V2.bias -> torch.Size([5])
Thanks for all the input, @aretor, @Ivan, and @DerekG.