I'm writing a wrapper for PyTorch Transformers. To keep it simple, I'll include a minimal example: a class Parent that is an abstract class for the classes My_bert(Parent) and My_gpt2(Parent). Because the LM models for BERT and GPT-2 (model_bert and model_gpt2) are both included in PyTorch Transformers, they have many similar functions, so I want to minimize code redundancy by implementing the otherwise identical functions in Parent.
My_bert and My_gpt2 differ essentially only in how the model is initialized and in one argument passed to the model; 99% of the functions use both models in an identical way.
The problem is the call to the wrapped model, which accepts different keyword arguments for each model.
Minimal code example:
```python
class Parent:
    """ My own class that is an abstract class for My_bert and My_gpt2 """
    def __init__(self):
        pass

    def fancy_arithmetic(self, text):
        print("do_fancy_stuff_that_works_identically_for_both_models(text=text)")

    def compute_model(self, text):
        return self.model(input_ids=text, masked_lm_labels=text)  # this line works for My_bert
        # return self.model(input_ids=text, labels=text)  # I'd need this line for My_gpt2


class My_bert(Parent):
    """ My own My_bert class that is initialized with the BERT PyTorch
    model (here model_bert), and uses methods from Parent """
    def __init__(self):
        self.model = model_bert()


class My_gpt2(Parent):
    """ My own My_gpt2 class that is initialized with the GPT-2 PyTorch
    model (here model_gpt2), and uses methods from Parent """
    def __init__(self):
        self.model = model_gpt2()


class model_gpt2:
    """ This class mocks the PyTorch Transformers GPT-2 model; it is just
    enough code to let you run this example """
    def __init__(self):
        pass

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def model(self, input_ids, labels):
        print("gpt2")


class model_bert:
    """ This class mocks the PyTorch Transformers BERT model """
    def __init__(self):
        pass

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def model(self, input_ids, masked_lm_labels):
        print("bert")


foo = My_bert()
foo.compute_model("bar")  # this works
bar = My_gpt2()
# bar.compute_model("rawr")  # this does not work: TypeError, model_gpt2.model() has no masked_lm_labels parameter
```
I know I can override Parent.compute_model inside the My_bert and My_gpt2 classes.
BUT since both model methods are so similar, I wonder if there is a way to say: "I'll pass you three arguments; use the ones you know":
```python
def compute_model(self, text):
    return self.model(input_ids=text, masked_lm_labels=text, labels=text)  # ignore the arguments you don't know
```
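Python will not silently drop unknown keyword arguments, but the "use the ones you know" behaviour can be approximated by inspecting the target's signature with `inspect.signature` and keeping only the matching keywords. The sketch below does that against the mocks from the question (changed to return a string instead of printing). The helper `known_kwargs` is a hypothetical name, and inspecting `self.model.model` is specific to these mocks; for a real PyTorch module you would presumably inspect `self.model.forward` instead, and if that method itself takes `**kwargs` this filter passes everything through, so treat this as an assumption to check against your library version.

```python
import inspect


def known_kwargs(func, **kwargs):
    """Hypothetical helper: keep only the keyword arguments func accepts."""
    params = inspect.signature(func).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return kwargs  # func takes **kwargs itself, so it accepts everything
    return {name: value for name, value in kwargs.items() if name in params}


class model_bert:
    """Mock BERT model from the question (returns instead of printing)."""
    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def model(self, input_ids, masked_lm_labels):
        return "bert"


class model_gpt2:
    """Mock GPT-2 model from the question (returns instead of printing)."""
    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def model(self, input_ids, labels):
        return "gpt2"


class Parent:
    def compute_model(self, text):
        offered = dict(input_ids=text, masked_lm_labels=text, labels=text)
        # The mocks keep their real parameters on .model (an assumption that
        # only holds for these mocks); filter against that signature, then
        # call the model object as usual.
        return self.model(**known_kwargs(self.model.model, **offered))


class My_bert(Parent):
    def __init__(self):
        self.model = model_bert()


class My_gpt2(Parent):
    def __init__(self):
        self.model = model_gpt2()


print(My_bert().compute_model("bar"))   # bert
print(My_gpt2().compute_model("rawr"))  # gpt2
```

The trade-off is that a misspelled keyword is now silently discarded instead of raising a `TypeError`, which can hide bugs.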
*args and **kwargs should take care of the issue you are running into.
In your code, change compute_model to accept arbitrary arguments and forward them to the model:
```python
def compute_model(self, *args, **kwargs):
    return self.model(*args, **kwargs)
```
Now the accepted arguments are determined by the model method of each wrapped class.
With this change, the following should work:
```python
foo = My_bert()
foo.compute_model("bar", "baz")
bar = My_gpt2()
bar.compute_model("rawr", "baz")
```
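Putting the pieces together, here is a self-contained sketch of the forwarding approach, using the mocks from the question changed to return a tuple instead of printing so the result is visible. Positional calls work for both models because both model methods take (input_ids, labels-argument) in the same order; keyword calls still need the right name per model.

```python
class model_bert:
    """Mock BERT model from the question (returns instead of printing)."""
    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def model(self, input_ids, masked_lm_labels):
        return ("bert", input_ids, masked_lm_labels)


class model_gpt2:
    """Mock GPT-2 model from the question (returns instead of printing)."""
    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def model(self, input_ids, labels):
        return ("gpt2", input_ids, labels)


class Parent:
    def compute_model(self, *args, **kwargs):
        # Forward whatever the caller passed; the wrapped model's own
        # signature decides which arguments are valid.
        return self.model(*args, **kwargs)


class My_bert(Parent):
    def __init__(self):
        self.model = model_bert()


class My_gpt2(Parent):
    def __init__(self):
        self.model = model_gpt2()


print(My_bert().compute_model("bar", "baz"))                   # ('bert', 'bar', 'baz')
print(My_gpt2().compute_model("rawr", "baz"))                  # ('gpt2', 'rawr', 'baz')
print(My_gpt2().compute_model(input_ids="rawr", labels="baz")) # ('gpt2', 'rawr', 'baz')
```

Passing masked_lm_labels to My_gpt2 still raises a `TypeError`, because the forwarded keyword reaches a method that does not declare it.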
If you are not familiar with *args and **kwargs, they allow a function to accept an arbitrary number of arguments. *args collects the unnamed (positional) arguments into a tuple, in the order they were received, and **kwargs (keyword arguments) collects the named arguments into a dict; when you unpack them again with * and ** in a call, each named argument is passed to the parameter with the matching name. So the following will also work:
```python
foo = My_bert()
foo.compute_model(input_ids="bar", masked_lm_labels="baz")
bar = My_gpt2()
bar.compute_model(input_ids="rawr", labels="baz")
```
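To see what *args and **kwargs actually receive, and how * and ** unpack them again at the call site (which is what compute_model does), here is a small illustration; `show` and `target` are made-up names for the demo.

```python
def show(*args, **kwargs):
    # Positional arguments arrive as a tuple, keyword arguments as a dict.
    return args, kwargs


def target(input_ids, labels):
    return input_ids + "/" + labels


print(show("bar", "baz"))                   # (('bar', 'baz'), {})
print(show(input_ids="bar", labels="baz"))  # ((), {'input_ids': 'bar', 'labels': 'baz'})
print(show("bar", labels="baz"))            # (('bar',), {'labels': 'baz'})

# Unpacking: the tuple fills positional parameters in order,
# the dict fills parameters by name.
pos = ("rawr",)
kw = {"labels": "baz"}
print(target(*pos, **kw))                   # rawr/baz
```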
Just a note: the names args and kwargs have no special meaning; only the * and ** do. You can name them anything, but the usual convention is args and kwargs.
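For example, this forwarding function behaves exactly like the args/kwargs version; only the parameter names differ (`forward_all`, `positional` and `named` are illustrative names).

```python
def forward_all(*positional, **named):
    # Only the * and ** matter; the names are just a convention.
    return positional, named


print(forward_all(1, 2, x=3))  # ((1, 2), {'x': 3})
```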