As we know, optional arguments (parameters with default values) must come at the end of the parameter list, like below:
def func(arg1, arg2, ..., argN=default)
I have seen some exceptions in the PyTorch package. For example, torch.randint has a leading optional argument among its positional arguments, as its docstring shows. How is that possible?
Docstring:
randint(low=0, high, size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
How can we define a function with a similar signature?
A single Python function cannot have optional parameters that precede required positional parameters. The Python Language Reference states:
8.6. Function definitions
[...] If a parameter has a default value, all following parameters up until the “*” must also have a default value — this is a syntactic restriction that is not expressed by the grammar.
Note that this rule does not apply to keyword-only parameters (those after "*" or "*args"), which never receive arguments by position.
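The rule and its keyword-only exception can be checked directly. The sketch below uses the built-in compile() so that the invalid definition raises SyntaxError at run time instead of stopping the whole script; the helper name is_valid_definition is hypothetical and only serves this illustration.

```python
def is_valid_definition(source: str) -> bool:
    """Return True if `source` compiles, False if it raises SyntaxError."""
    try:
        compile(source, "<example>", "exec")
    except SyntaxError:
        return False
    return True

# A default before required positional parameters is rejected.
print(is_valid_definition("def f(low=0, high, size): pass"))     # False
# Required parameters after "*" are keyword-only, so they are allowed.
print(is_valid_definition("def f(low=0, *, high, size): pass"))  # True
# Defaults at the end of the positional list are fine.
print(is_valid_definition("def f(high, size, low=0): pass"))     # True
```

The exact SyntaxError message differs between Python versions, so the sketch only checks whether compilation fails.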
If desired, one can emulate such behaviour by manually implementing the argument-to-parameter matching. For example, one can dispatch based on arity, or explicitly match variadic arguments.
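def leading_default(*args):
    # match arguments to "parameters": prepend the default for low,
    # then take the last three values (extra leading values are silently dropped)
    *_, low, high, size = 0, *args
    print(low, high, size)

leading_default(1, 2)     # 0 1 2
leading_default(1, 2, 3)  # 1 2 3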
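The other approach mentioned, dispatching on arity, can be sketched by branching on len(args). This is a minimal sketch, assuming exactly two or three positional arguments; the name randint_args and the keyword-only dtype option are hypothetical stand-ins for the randint(low=0, high, size, *, ...) shape, not PyTorch's actual implementation.

```python
def randint_args(*args, dtype=None):
    """Hypothetical helper: map 2 or 3 positional arguments to (low, high, size)."""
    if len(args) == 2:
        # Only high and size were given, so low takes its default of 0.
        low, (high, size) = 0, args
    elif len(args) == 3:
        low, high, size = args
    else:
        raise TypeError(f"expected 2 or 3 positional arguments, got {len(args)}")
    # dtype is keyword-only, like the options after "*" in the docstring.
    return low, high, size, dtype

print(randint_args(10, (2, 2)))                  # (0, 10, (2, 2), None)
print(randint_args(5, 10, (3,), dtype="int64"))  # (5, 10, (3,), 'int64')
```

Unlike the starred-unpacking version, this rejects a fourth positional argument with a TypeError instead of silently discarding it.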
A simple form of dispatch achieves function overloading by iterating over candidate signatures and calling the first one that matches the arguments.
import inspect

class MatchOverload:
    """Overload a function via explicitly matching arguments to parameters on call"""
    def __init__(self, base_case=None):
        self.cases = [base_case] if base_case is not None else []

    def overload(self, call):
        self.cases.append(call)
        return self

    def __call__(self, *args, **kwargs):
        failures = []
        for call in self.cases:
            try:
                inspect.signature(call).bind(*args, **kwargs)
            except TypeError as err:
                failures.append(str(err))
            else:
                return call(*args, **kwargs)
        raise TypeError(', '.join(failures))

@MatchOverload
def func(high, size):
    print('two', 0, high, size)

@func.overload
def func(low, high, size):
    print('three', low, high, size)

func(1, 2, size=3)    # three 1 2 3
func(1, 2)            # two 0 1 2
func(1, 2, 3, low=4)  # TypeError: too many positional arguments, multiple values for argument 'low'
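If static type checkers and IDEs should also see both call forms, one option not covered in the answer is typing.overload: the decorated definitions only declare signatures, and the final undecorated definition is what actually runs. This is a sketch under those assumptions; the name typed_randint_args and the Size alias are hypothetical, and the "/" marker (Python 3.8+) makes the declared parameters positional-only so they agree with the *args implementation.

```python
from typing import Any, Tuple, overload

Size = Tuple[int, ...]  # hypothetical alias for a shape tuple

# Signatures seen by type checkers only; "/" makes them positional-only.
@overload
def typed_randint_args(high: int, size: Size, /) -> Tuple[int, int, Size]: ...
@overload
def typed_randint_args(low: int, high: int, size: Size, /) -> Tuple[int, int, Size]: ...

def typed_randint_args(*args: Any) -> Tuple[int, int, Size]:
    # This final definition replaces the @overload ones at run time,
    # so the actual dispatch still happens on the number of arguments.
    if len(args) == 2:
        high, size = args
        return 0, high, size
    if len(args) == 3:
        low, high, size = args
        return low, high, size
    raise TypeError(f"expected 2 or 3 positional arguments, got {len(args)}")

print(typed_randint_args(10, (2, 2)))   # (0, 10, (2, 2))
print(typed_randint_args(5, 10, (3,)))  # (5, 10, (3,))
```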