I have three VGG networks (VGGA, VGGB and VGG*), trained with the following training function:
def train(nets, loaders, optimizer, criterion, epochs=20, dev=None, save_param=False, model_name="valerio"):
    """Jointly train three feature extractors that share one classifier head.

    ``nets`` is indexed as ``[net_a, net_b, net_star, classifier]``. For every
    batch, ``net_a`` receives the batch of the first loader, ``net_b`` the
    batch of the second loader and ``net_star`` the concatenation of both.
    The loss is the sum of the three ``criterion`` terms plus, for every
    ``nn.Conv2d`` layer, ``lambda_reg`` times the MSE between
    ``combo_fn(param_a, param_b)`` and the matching ``net_star`` parameter.

    ``module_unwrap``, ``combo_fn``, ``lambda_reg`` and ``pb`` are read from
    module scope.
    NOTE(review): ``module_unwrap(net, True)`` is assumed to return a mapping
    from layer name to layer with the same keys for all three networks --
    confirm against its definition.

    Args:
        nets: sequence ``[net_a, net_b, net_star, classifier]`` of modules.
        loaders: mapping from split name ("train", "val", "test") to a pair
            of iterables, each yielding ``(inputs, labels)`` batches.
        optimizer: optimizer stepped once per training batch.
        criterion: loss module called as ``criterion(predictions, labels)``.
        epochs: number of passes over all three splits.
        dev: device the networks, losses and batches are moved to.
        save_param: when true, the four state dicts are written to
            ``f"{model_name}.pth"``.
        model_name: file stem used when ``save_param`` is true.

    Returns:
        ``(history_loss, history_accuracy)``: per-split lists with one entry
        per epoch. Each accuracy entry is ``[acc_a, acc_b, acc_star]``.
    """
    splits = ("train", "val", "test")
    nets = [n.to(dev) for n in nets]
    model_a = module_unwrap(nets[0], True)
    model_b = module_unwrap(nets[1], True)
    model_c = module_unwrap(nets[2], True)
    reg_loss = nn.MSELoss()
    criterion.to(dev)
    reg_loss.to(dev)

    def predict(features):
        # squeeze() also removes a batch dimension of size 1 (e.g. a final
        # batch holding a single sample); put it back so dim=1 stays valid.
        out = torch.squeeze(nets[3](features))
        return out.unsqueeze(0) if out.dim() == 1 else out

    history_loss = {split: [] for split in splits}
    history_accuracy = {split: [] for split in splits}

    for epoch in range(epochs):
        sum_loss = {split: 0 for split in splits}
        sum_accuracy = {split: [0, 0, 0] for split in splits}
        # zip() below stops at the shorter loader, so count the batches that
        # were really processed instead of trusting len() of the first one.
        batch_count = {split: 0 for split in splits}

        for split in splits:
            training = split == "train"
            for n in nets:
                n.train(training)

            progbar = None
            if training:
                widgets = [
                    ' [', pb.Timer(), '] ',
                    pb.Bar(),
                    ' [', pb.ETA(), '] ', pb.Variable('ta', '[Train Acc: {formatted_value}]')
                ]
                progbar = pb.ProgressBar(max_value=len(loaders[split][0]), widgets=widgets, redirect_stdout=True)

            # Autograd bookkeeping is only needed while training.
            with torch.set_grad_enabled(training):
                batches = zip(loaders[split][0], loaders[split][1])
                for j, ((input_a, labels_a), (input_b, labels_b)) in enumerate(batches):
                    input_a = input_a.to(dev)
                    input_b = input_b.to(dev)
                    labels_a = labels_a.long().to(dev)
                    labels_b = labels_b.long().to(dev)
                    inputs = torch.cat([input_a, input_b], axis=0)
                    labels = torch.cat([labels_a, labels_b])

                    if training:
                        optimizer.zero_grad()

                    pred_a = predict(nets[0](input_a))
                    pred_b = predict(nets[1](input_b))
                    pred_c = predict(nets[2](inputs))
                    loss = criterion(pred_a, labels_a) + criterion(pred_b, labels_b) + criterion(pred_c, labels)

                    # Pull every conv parameter of net_star towards the
                    # combination of the matching parameters of net_a/net_b.
                    for layer_name in model_a:
                        layer_a = model_a[layer_name]
                        layer_b = model_b[layer_name]
                        layer_c = model_c[layer_name]
                        if isinstance(layer_a, nn.Conv2d):
                            loss = loss + lambda_reg * reg_loss(combo_fn(layer_a.weight, layer_b.weight), layer_c.weight)
                            if layer_a.bias is not None:
                                loss = loss + lambda_reg * reg_loss(combo_fn(layer_a.bias, layer_b.bias), layer_c.bias)

                    sum_loss[split] += loss.item()

                    if training:
                        loss.backward()
                        optimizer.step()

                    # Accuracy = fraction of samples whose arg-max class
                    # equals the label, for each of the three networks.
                    batch_accuracy = []
                    for pred, target in ((pred_a, labels_a), (pred_b, labels_b), (pred_c, labels)):
                        _, pred_label = torch.max(pred, dim=1)
                        batch_accuracy.append((pred_label == target).float().sum().item() / len(target))
                    for i, value in enumerate(batch_accuracy):
                        sum_accuracy[split][i] += value
                    batch_count[split] += 1

                    if training:
                        progbar.update(j, ta=batch_accuracy[2])

            if progbar is not None:
                progbar.finish()

        # max(..., 1) keeps an empty split from raising ZeroDivisionError.
        epoch_loss = {split: sum_loss[split] / max(batch_count[split], 1) for split in splits}
        epoch_accuracy = {
            split: [total / max(batch_count[split], 1) for total in sum_accuracy[split]]
            for split in splits
        }

        print(f"Epoch {epoch + 1}:")
        for split in splits:
            history_loss[split].append(epoch_loss[split])
            history_accuracy[split].append(epoch_accuracy[split])
            print(f"\t{split}\tLoss: {epoch_loss[split]:0.5}\tVGG 1:{epoch_accuracy[split][0]:0.5}"
                  f"\tVGG 2:{epoch_accuracy[split][1]:0.5}\tVGG *:{epoch_accuracy[split][2]:0.5}")

        # NOTE(review): the pasted source lost its indentation; saving once
        # per epoch (overwriting the same file) is the reading assumed here.
        if save_param:
            torch.save({'vgg_a':nets[0].state_dict(),'vgg_b':nets[1].state_dict(),'vgg_star':nets[2].state_dict(),'classifier':nets[3].state_dict()},f'{model_name}.pth')

    return history_loss, history_accuracy
For each epoch of training the result is this:
Then, I have a combined model which sums the weights of VGGA and VGGB:
# 'TRAIN' runs the training function; any other value evaluates the saved
# checkpoint, first as trained and then with the combined conv parameters.
DO = 'TEST'
if DO == 'TRAIN':
    train(nets, loaders, optimizer, criterion, epochs=50, dev=dev, save_param=True)
else:
    # These state dicts come from the training function.
    state_dicts = torch.load('valerio.pth')
    backbones = ((model1, 'vgg_a'), (model2, 'vgg_b'), (model3, 'vgg_star'))
    for backbone, entry in backbones:
        backbone.load_state_dict(state_dicts[entry])
    classifier.load_state_dict(state_dicts['classifier'])

    # Accuracy of each network exactly as it was trained.
    for backbone, _ in backbones:
        test(backbone, classifier, test_loader_all)

    # Build the aggregated model: entries whose key mentions 'conv' are the
    # combination of VGG A and VGG B, everything else is taken from VGG*.
    summed_state_dict = OrderedDict()
    for key in state_dicts['vgg_star']:
        if 'conv' in key:
            print(key)
            summed_state_dict[key] = combo_fn(state_dicts['vgg_a'][key], state_dicts['vgg_b'][key])
        else:
            summed_state_dict[key] = state_dicts['vgg_star'][key]
    model3.load_state_dict(summed_state_dict)
    test(model3, classifier, test_loader_all)
where the test function is this:
def test(net, classifier, loader):
    """Print and return the mean per-batch accuracy of ``classifier(net(x))``.

    Both modules are moved to the module-level device ``dev`` and switched to
    evaluation mode; they are left in that state when the function returns.

    Args:
        net: feature extractor applied to every input batch.
        classifier: head applied to the features; after squeezing, its output
            is reduced with an arg-max over dimension 1.
        loader: sized iterable yielding ``(inputs, labels)`` batches.

    Returns:
        The average of the per-batch accuracies (sum of batch accuracies
        divided by ``len(loader)``).
    """
    net.to(dev)
    classifier.to(dev)
    net.eval()
    # The head must be in eval mode as well, otherwise layers such as
    # dropout or batch norm keep their training behaviour during the test.
    classifier.eval()
    sum_accuracy = 0
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for batch, labels in loader:
            batch = batch.to(dev)
            labels = labels.float().to(dev)
            pred = torch.squeeze(classifier(net(batch)))
            # squeeze() also removes a batch dimension of size 1; restore it
            # so that the arg-max over dim=1 below stays valid.
            if pred.dim() == 1:
                pred = pred.unsqueeze(0)
            _, pred_label = torch.max(pred, dim=1)
            correct = (pred_label == labels).float()
            sum_accuracy += correct.sum().item() / len(labels)
    epoch_accuracy = sum_accuracy / len(loader)
    print(f"Accuracy after sum: {epoch_accuracy:0.5}")
    return epoch_accuracy
And the result of this aggregation is the following:
I want to modify my training function so that it prints the same things as in the first image, plus the accuracy of the aggregated model (the part highlighted in red in the second picture). So basically, for each epoch, it should compute the accuracies of VGGA, VGGB, VGG* and the combined VGG, print them, and continue with the training. I tried to add this model combination but failed: I was not able to insert it into each epoch, only at the end of the training. I was trying to add, in the training function, between print(f"Epoch {epoch + 1}:")
and
# Update history
for split in ["train", "val", "test"]:
the state_dict code shown above, but I am doing something wrong and I do not know what.
Can I reuse the test code, or do I have to write new code?
Do you think I have to save the state_dict for each epoch, or can I do something else, like model_c.parameters()=model_a.parameters()+model_b.parameters()
(which does not work; I have already tried it)?
I solved it; here is how I modified my training function:
def train(nets, loaders, optimizer, criterion, epochs=20, dev=None, save_param=False, model_name="valerio"):
    """Jointly train three feature extractors and report the aggregated model.

    ``nets`` is indexed as ``[net_a, net_b, net_star, classifier]``. For every
    batch, ``net_a`` receives the batch of the first loader, ``net_b`` the
    batch of the second loader and ``net_star`` the concatenation of both.
    The loss is the sum of the three ``criterion`` terms plus, for every
    ``nn.Conv2d`` layer, ``lambda_reg`` times the MSE between
    ``combo_fn(param_a, param_b)`` and the matching ``net_star`` parameter.

    After each epoch the three networks are evaluated with ``test`` on
    ``test_loader_all``, followed by an aggregated model whose ``'conv'``
    entries are ``combo_fn`` of ``net_a`` and ``net_b``. The trained
    parameters of ``net_star`` are restored afterwards so that this
    evaluation does not interfere with the following epochs.

    ``module_unwrap``, ``combo_fn``, ``lambda_reg``, ``pb``, ``test``,
    ``test_loader_all`` and ``OrderedDict`` are read from module scope.
    NOTE(review): ``module_unwrap(net, True)`` is assumed to return a mapping
    from layer name to layer with the same keys for all three networks --
    confirm against its definition.

    Args:
        nets: sequence ``[net_a, net_b, net_star, classifier]`` of modules.
        loaders: mapping from split name ("train", "val", "test") to a pair
            of iterables, each yielding ``(inputs, labels)`` batches.
        optimizer: optimizer stepped once per training batch.
        criterion: loss module called as ``criterion(predictions, labels)``.
        epochs: number of passes over all three splits.
        dev: device the networks, losses and batches are moved to.
        save_param: when true, the four state dicts are written to
            ``f"{model_name}.pth"``.
        model_name: file stem used when ``save_param`` is true.

    Returns:
        ``(history_loss, history_accuracy)``: per-split lists with one entry
        per epoch. Each accuracy entry is ``[acc_a, acc_b, acc_star]``.
    """
    splits = ("train", "val", "test")
    nets = [n.to(dev) for n in nets]
    model_a = module_unwrap(nets[0], True)
    model_b = module_unwrap(nets[1], True)
    model_c = module_unwrap(nets[2], True)
    reg_loss = nn.MSELoss()
    criterion.to(dev)
    reg_loss.to(dev)

    def predict(features):
        # squeeze() also removes a batch dimension of size 1 (e.g. a final
        # batch holding a single sample); put it back so dim=1 stays valid.
        out = torch.squeeze(nets[3](features))
        return out.unsqueeze(0) if out.dim() == 1 else out

    history_loss = {split: [] for split in splits}
    history_accuracy = {split: [] for split in splits}

    for epoch in range(epochs):
        sum_loss = {split: 0 for split in splits}
        sum_accuracy = {split: [0, 0, 0] for split in splits}
        # zip() below stops at the shorter loader, so count the batches that
        # were really processed instead of trusting len() of the first one.
        batch_count = {split: 0 for split in splits}

        for split in splits:
            training = split == "train"
            for n in nets:
                n.train(training)

            progbar = None
            if training:
                widgets = [
                    ' [', pb.Timer(), '] ',
                    pb.Bar(),
                    ' [', pb.ETA(), '] ', pb.Variable('ta', '[Train Acc: {formatted_value}]')
                ]
                progbar = pb.ProgressBar(max_value=len(loaders[split][0]), widgets=widgets, redirect_stdout=True)

            # Autograd bookkeeping is only needed while training.
            with torch.set_grad_enabled(training):
                batches = zip(loaders[split][0], loaders[split][1])
                for j, ((input_a, labels_a), (input_b, labels_b)) in enumerate(batches):
                    input_a = input_a.to(dev)
                    input_b = input_b.to(dev)
                    labels_a = labels_a.long().to(dev)
                    labels_b = labels_b.long().to(dev)
                    inputs = torch.cat([input_a, input_b], axis=0)
                    labels = torch.cat([labels_a, labels_b])

                    if training:
                        optimizer.zero_grad()

                    pred_a = predict(nets[0](input_a))
                    pred_b = predict(nets[1](input_b))
                    pred_c = predict(nets[2](inputs))
                    loss = criterion(pred_a, labels_a) + criterion(pred_b, labels_b) + criterion(pred_c, labels)

                    # Pull every conv parameter of net_star towards the
                    # combination of the matching parameters of net_a/net_b.
                    for layer_name in model_a:
                        layer_a = model_a[layer_name]
                        layer_b = model_b[layer_name]
                        layer_c = model_c[layer_name]
                        if isinstance(layer_a, nn.Conv2d):
                            loss = loss + lambda_reg * reg_loss(combo_fn(layer_a.weight, layer_b.weight), layer_c.weight)
                            if layer_a.bias is not None:
                                loss = loss + lambda_reg * reg_loss(combo_fn(layer_a.bias, layer_b.bias), layer_c.bias)

                    sum_loss[split] += loss.item()

                    if training:
                        loss.backward()
                        optimizer.step()

                    # Accuracy = fraction of samples whose arg-max class
                    # equals the label, for each of the three networks.
                    batch_accuracy = []
                    for pred, target in ((pred_a, labels_a), (pred_b, labels_b), (pred_c, labels)):
                        _, pred_label = torch.max(pred, dim=1)
                        batch_accuracy.append((pred_label == target).float().sum().item() / len(target))
                    for i, value in enumerate(batch_accuracy):
                        sum_accuracy[split][i] += value
                    batch_count[split] += 1

                    if training:
                        progbar.update(j, ta=batch_accuracy[2])

            if progbar is not None:
                progbar.finish()

        # max(..., 1) keeps an empty split from raising ZeroDivisionError.
        epoch_loss = {split: sum_loss[split] / max(batch_count[split], 1) for split in splits}
        epoch_accuracy = {
            split: [total / max(batch_count[split], 1) for total in sum_accuracy[split]]
            for split in splits
        }

        print(f"Epoch {epoch + 1}:")
        for split in splits:
            history_loss[split].append(epoch_loss[split])
            history_accuracy[split].append(epoch_accuracy[split])
            print(f"\t{split}\tLoss: {epoch_loss[split]:0.5}\tVGG 1:{epoch_accuracy[split][0]:0.5}"
                  f"\tVGG 2:{epoch_accuracy[split][1]:0.5}\tVGG *:{epoch_accuracy[split][2]:0.5}")

        # NOTE(review): the pasted source lost its indentation; running the
        # save and the evaluation below once per epoch is the reading assumed
        # here, since the aggregated accuracy is wanted for every epoch.
        if save_param:
            torch.save({'vgg_a':nets[0].state_dict(),'vgg_b':nets[1].state_dict(),'vgg_star':nets[2].state_dict(),'classifier':nets[3].state_dict()},f'{model_name}.pth')

        test(nets[0], nets[3], test_loader_all)
        test(nets[1], nets[3], test_loader_all)
        test(nets[2], nets[3], test_loader_all)

        state_a = nets[0].state_dict()
        state_b = nets[1].state_dict()
        state_c = nets[2].state_dict()
        # state_dict() tensors share storage with the live parameters, so a
        # real copy is needed to bring net_star back after the evaluation.
        trained_state_c = OrderedDict((key, value.clone()) for key, value in state_c.items())

        summed_state_dict = OrderedDict()
        for key in state_c:
            if 'conv' in key:
                summed_state_dict[key] = combo_fn(state_a[key], state_b[key])
            else:
                summed_state_dict[key] = state_c[key]
        nets[2].load_state_dict(summed_state_dict)
        test(nets[2], nets[3], test_loader_all)

        # Undo the overwrite: training continues from the trained VGG*.
        nets[2].load_state_dict(trained_state_c)

    return history_loss, history_accuracy
The edited part is the last few lines of the function.