Hey there, I am trying to create a basic Sentence Transformer model for few-shot learning. However, while fitting I observed that the changes made to the model are minuscule, because the model has been trained on 1B+ pairs whereas I train it on around 40 pairs per epoch. To deal with this problem I decided to add a linear layer on top of the sentence transformer so that it can learn embeddings specific to my dataset. However, there seems to be no forward function for sentence transformers. There is an alternative, the model.encode() method, but it does not change the model parameters. To summarize: I want to create a network that does a forward pass through the sentence transformer, then through the linear layer, and finally produces a loss that can be backpropagated through the whole model. Any help would be appreciated. Thank you.
Here is a simple code snippet that adds a single linear layer on top of a sentence transformer:
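```python
import torch
from sentence_transformers import SentenceTransformer


class SentenceTransformerWithLinearLayer(torch.nn.Module):
    def __init__(self, transformer_model_name):
        super().__init__()
        # Load the sentence transformer model (itself a torch.nn.Module)
        self.sentence_transformer = SentenceTransformer(transformer_model_name)
        last_layer_dimension = self.sentence_transformer.get_sentence_embedding_dimension()
        # New linear layer with 16 output dimensions, on the same device as the transformer
        self.linear = torch.nn.Linear(last_layer_dimension, 16).to(self.sentence_transformer.device)

    def forward(self, x):
        # Accept a single string or a list of strings
        if isinstance(x, str):
            x = [x]
        # encode() does not track gradients, so tokenize and call the module
        # directly to keep the transformer weights in the autograd graph
        device = self.sentence_transformer.device
        features = self.sentence_transformer.tokenize(x)
        features = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in features.items()}
        x = self.sentence_transformer(features)["sentence_embedding"]  # (batch, dim)
        # Pass through the linear layer
        x = self.linear(x)  # (batch, 16)
        return x
```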
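Why the forward pass should not go through encode(): the question notes that encode() leaves the weights unchanged, which is consistent with it running its forward pass without gradient tracking (as far as I know it wraps the call in torch.no_grad()). A tensor produced that way is detached from the transformer, so a loss computed on it can only update layers applied afterwards. The plain-PyTorch sketch below illustrates the effect; the two small Linear layers are hypothetical stand-ins for the sentence transformer and the new head, and random tensors stand in for tokenized sentences.

```python
import torch


def encoder_gets_grad(track_gradients: bool) -> bool:
    torch.manual_seed(0)
    encoder = torch.nn.Linear(8, 4)  # hypothetical stand-in for the sentence transformer
    head = torch.nn.Linear(4, 2)     # stand-in for the new linear layer
    x = torch.randn(3, 8)            # stand-in for tokenized input
    if track_gradients:
        emb = encoder(x)             # like calling the module directly
    else:
        with torch.no_grad():        # like encode(): no autograd graph is recorded
            emb = encoder(x)
    loss = head(emb).pow(2).mean()
    loss.backward()
    # The head always receives gradients; the encoder only if the graph was kept
    assert head.weight.grad is not None
    return encoder.weight.grad is not None


print(encoder_gets_grad(track_gradients=False))  # False
print(encoder_gets_grad(track_gradients=True))   # True
```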
The SentenceTransformerWithLinearLayer module can then be called like any other PyTorch module. In this example I load all-mpnet-base-v2 as the base sentence transformer. The input "Hello world" is passed through the sentence transformer and then the linear layer, resulting in a 16-dimensional vector (a tensor of shape (1, 16)).
```python
model = SentenceTransformerWithLinearLayer("all-mpnet-base-v2")
output = model("Hello world")  # tensor of shape (1, 16)
```
The output can then be used in a loss function, e.g. MSELoss:
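```python
loss_function = torch.nn.MSELoss()
...
expected = ...  # target tensor with the same shape as output, (1, 16)
loss = loss_function(output, expected)
loss.backward()  # compute gradients of the loss
...
```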
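Since the data here comes as pairs, one way to turn the 16-dimensional outputs into a training signal is to score each pair by cosine similarity and regress that score onto a similarity label; as far as I know this is what sentence-transformers' CosineSimilarityLoss computes by default. The sketch below runs such a loop in plain PyTorch so it executes without downloading a model. The small encoder, the random input vectors, the 40 labelled pairs and both learning rates are hypothetical placeholders; with the real model you would embed lists of sentences through a forward pass that keeps the transformer in the graph.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins: the encoder plays the role of the pretrained sentence
# transformer, and random vectors play the role of tokenized sentences
encoder = torch.nn.Sequential(torch.nn.Linear(32, 24), torch.nn.Tanh())
head = torch.nn.Linear(24, 16)  # new linear layer with 16 output dimensions


def embed(x: torch.Tensor) -> torch.Tensor:
    # Forward pass through the encoder, then the linear layer, with autograd on
    return head(encoder(x))


# Hypothetical few-shot data: 40 pairs with similarity labels in [0, 1]
left = torch.randn(40, 32)
right = torch.randn(40, 32)
labels = torch.rand(40)

# Copy of the encoder weights, to check later that the pretrained part moved
encoder_before = [p.detach().clone() for p in encoder.parameters()]

# Smaller learning rate for the pretrained part, larger for the new head
optimizer = torch.optim.AdamW([
    {"params": encoder.parameters(), "lr": 1e-4},
    {"params": head.parameters(), "lr": 1e-2},
])
loss_function = torch.nn.MSELoss()

losses = []
for _ in range(200):  # each iteration is one full pass (epoch) over the 40 pairs
    optimizer.zero_grad()
    # Cosine similarity of each pair's embeddings, regressed onto its label
    scores = F.cosine_similarity(embed(left), embed(right), dim=1)
    loss = loss_function(scores, labels)
    loss.backward()  # gradients reach both the head and the encoder
    optimizer.step()
    losses.append(loss.item())

encoder_changed = any(
    not torch.equal(before, after.detach())
    for before, after in zip(encoder_before, encoder.parameters())
)
print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
print("encoder updated:", encoder_changed)
```

Giving the pretrained part a smaller learning rate than the freshly initialised head is a common choice rather than a requirement; with only around 40 pairs it limits how far the pretrained weights drift while still letting them adapt.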