
PyTorch Factorial Function


There does not seem to be a PyTorch function for computing a factorial. Is there a way to do this in PyTorch? I want to compute the Poisson distribution manually in PyTorch (I am aware that this exists: https://pytorch.org/docs/stable/generated/torch.poisson.html), and the formula has a factorial in the denominator.

Poisson Distribution: https://en.wikipedia.org/wiki/Poisson_distribution
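For reference, this is the Poisson probability mass function with rate λ. Because k! = Γ(k + 1), its logarithm can be written with the log-gamma function instead of an explicit factorial:

$$
P(K = k) = \frac{\lambda^{k} e^{-\lambda}}{k!}, \qquad k = 0, 1, 2, \dots
$$

$$
\log P(K = k) = k \log \lambda - \lambda - \log \Gamma(k + 1)
$$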


Solution

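  • You can find one as torch.jit._builtins.math.factorial, but, like the factorial exposed by NumPy and SciPy (see "Factorial in numpy and scipy"), it is just Python's builtin math.factorial under another name. These are private or legacy aliases and may be missing in newer releases (np.math, for example, was deprecated in NumPy 1.25 and removed in NumPy 2.0):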

    import math
    
    import numpy as np
    import scipy as sp
    import torch
    
    
    print(torch.jit._builtins.math.factorial is math.factorial)  # True
    print(np.math.factorial is math.factorial)  # True
    print(sp.math.factorial is math.factorial)  # True
    

    In contrast, SciPy also provides a separate, "special" factorial function, scipy.special.factorial. Unlike math.factorial, it operates element-wise on arrays (and returns floating-point values unless exact=True is passed):

    from scipy import special
    
    print(special.factorial is math.factorial)  # False
    
    # all the factorial functions mentioned above
    factorials = (
        math.factorial,
        torch.jit._builtins.math.factorial,
        np.math.factorial,
        sp.math.factorial,
        special.factorial,
    )
    
    # Let's run some tests
    tnsr = torch.tensor(3)
    
    for fn in factorials:
        try:
            out = fn(tnsr)
        except Exception as err:
            print(fn.__name__, fn.__module__, ':', err)
        else:
            print(fn.__name__, fn.__module__, ':', out)
    
    # Output:
    # factorial math : 6
    # factorial math : 6
    # factorial math : 6
    # factorial math : 6
    # factorial scipy.special._basic : tensor(6., dtype=torch.float64)
    
    tnsr = torch.tensor([1, 2, 3])
    
    for fn in factorials:
        try:
            out = fn(tnsr)
        except Exception as err:
            print(fn.__name__, fn.__module__, ':', err)
        else:
            print(fn.__name__, fn.__module__, ':', out)
    
    # Output:
    # factorial math : only integer tensors of a single element can be converted to an index
    # factorial math : only integer tensors of a single element can be converted to an index
    # factorial math : only integer tensors of a single element can be converted to an index
    # factorial math : only integer tensors of a single element can be converted to an index
    # factorial scipy.special._basic : tensor([1., 2., 6.], dtype=torch.float64)
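If you would rather stay entirely in PyTorch, one option is to use the identity n! = Γ(n + 1): torch.lgamma computes log Γ element-wise on tensors, so log(n!) is torch.lgamma(n + 1). The sketch below assumes non-negative integer inputs, and the helper names factorial, poisson_log_pmf and poisson_pmf are hypothetical, not part of PyTorch. It cross-checks the result against torch.distributions.Poisson.

```python
import math

import torch


def factorial(n: torch.Tensor) -> torch.Tensor:
    """Element-wise n! for a tensor of non-negative integers (hypothetical helper).

    Uses n! = Gamma(n + 1), so the result is a float64 approximation;
    it overflows to inf for n > 170.
    """
    return torch.exp(torch.lgamma(n.to(torch.float64) + 1))


def poisson_log_pmf(k: torch.Tensor, rate: torch.Tensor) -> torch.Tensor:
    """log P(K = k) = k*log(rate) - rate - log(k!), evaluated in log space."""
    k = k.to(torch.float64)
    rate = rate.to(torch.float64)
    # xlogy(k, rate) equals k*log(rate) but is defined as 0 when k == 0 (even if rate == 0)
    return torch.xlogy(k, rate) - rate - torch.lgamma(k + 1)


def poisson_pmf(k: torch.Tensor, rate: torch.Tensor) -> torch.Tensor:
    """P(K = k), exponentiated only at the end to avoid overflowing k!."""
    return torch.exp(poisson_log_pmf(k, rate))


if __name__ == "__main__":
    k = torch.arange(0, 6)
    rate = torch.tensor(2.5)
    print(factorial(k))
    print(poisson_pmf(k, rate))
    # Cross-check against PyTorch's built-in Poisson distribution
    reference = torch.distributions.Poisson(rate.double()).log_prob(k.double()).exp()
    print(torch.allclose(poisson_pmf(k, rate), reference))
```

Working in log space and exponentiating only at the end keeps intermediate values small, so the probabilities stay finite even when k! itself would overflow a float64.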