Search code examples
jax, google-jax

Jax.lax.scan with arguments?


I'm trying to speed up my code by rewriting for loops as jax.lax.scan, but I ran into a problem: the scan function needs access to parameters that are passed to the main function. How can I do that?

Here I get `NameError: name 'coefs' is not defined` from within ParseRow, which is expected, since I never pass coefs into ParseRow.

from functools import partial
import jax.lax
import jax.numpy as jnp
import jaxopt
import numpy as np

# Example rows. Each row is laid out as:
#   [0]      value compared against 1.0 / 2.0 in ParseRow
#   [1], [2] the two multipliers handed to goodPrediction
#   [3:17], [17:31], [31:43], [43:57] index/flag groups read by ParseRow
dataList = [
    [
        1, 1.36, 3.41,
        5, 7, 2, 6, 12, 5, 10, 1, 7, 1, 12, 10, 12, 4,
        10, 12, 7, 11, 5, 10, 3, 6, 12, 6, 5, 3, 5, 9,
        8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7,
        0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0,
    ],
    [
        1, 1.45, 2.77,
        5, 7, 2, 6, 12, 5, 10, 1, 7, 1, 12, 10, 12, 4,
        10, 12, 7, 11, 5, 10, 3, 6, 12, 6, 5, 3, 5, 9,
        8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7,
        0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0,
    ],
    [
        1, 1.73, 2.21,
        5, 7, 2, 6, 12, 5, 10, 1, 7, 1, 12, 10, 12, 4,
        10, 12, 7, 11, 5, 10, 3, 6, 12, 6, 5, 3, 5, 9,
        8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7,
        0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0,
    ],
]
# Nested-tuple copy of the rows (hashable, unlike the nested lists).
dataTuple = tuple(map(tuple, dataList))
# Random starting vector with 532 entries drawn from [0, 1).
start = np.random.sample(532)

def ParseRow (carry, row):
    """Scan body: compute one row's balance and add it to the running carry.

    Args:
        carry: running total threaded through ``jax.lax.scan``.
        row: one row of the scanned data array. Positions read here:
            ``row[0]`` (compared against 1.0 and 2.0), ``row[1]`` and
            ``row[2]`` (multipliers passed to ``goodPrediction``), and the
            groups ``row[3:17]``, ``row[17:31]``, ``row[31:43]``,
            ``row[43:57]``.

    Returns:
        ``(carry + balance, balance)``: the updated carry and this row's
        own contribution as the per-step scan output.

    Note:
        ``coefs`` is read as a free variable. It is neither a parameter nor
        a local of this function, and no module-level ``coefs`` is defined
        in this file, so the first lookup raises ``NameError``.
    """
    balance = 0.0

    result = row[0]
    fOdds = row[1]
    dOdds = row[2]

    # 12-slot accumulator, filled by the first loop and rescaled by the second.
    hCoefs = jnp.array(np.zeros(12))
    
    for p in range (0, 14):
        # Per-p lookups into the row; s is shifted by -1 before use as an index.
        s = (row[3+p]-1).astype(int)
        # NOTE(review): unlike s and hSign, h is not shifted by -1 although it
        # indexes the 12-element hCoefs; confirm row[17:31] is 0-based.
        h = (row[17+p]).astype(int)
        r = (row[43+p]).astype(int)
                
        # coefs offsets used here: [0, 14) by p, 14 + p*12 + s, 182 + p*12 + h.
        bCoef = coefs[p]
        sCoef = coefs[14 + (p * 12) + s]
        hCoef = coefs[182 + (p * 12) + h]
            
        # Use coefs[350 + p] when the flag r equals 1, otherwise the neutral factor 1.0.
        rCoef = jax.lax.cond (r == 1.0, lambda x: x, lambda x: 1.0, coefs[350 +p] )
                
        pStrength = bCoef * sCoef * hCoef * rCoef

        # Accumulate this p's product into slot h.
        hCoefs = hCoefs.at[h].set(hCoefs[h] + pStrength)

    for h in range (0, 12):
        # Scale each slot by the coefficient selected via row[31 + h] - 1.
        hSign = (row[31+h]-1).astype(int)
        hCoefs = hCoefs.at[h].set(hCoefs[h] * coefs [364 + (h*12) + hSign]) 
                            
    fPoints = 0.0
    dPoints = 0.0
            
    # Two weighted sums over the 12 slots; 520 + 11 = 531 is the highest
    # static coefs offset, consistent with the 532-element start vector.
    for h in range (0, 12):
        fPoints += hCoefs[h] * coefs[508+h]
        dPoints += hCoefs[h] * coefs[520+h]
    
    # Only the branch matching row[0] contributes; the other adds 0.0.
    # The order of the two points values is swapped between the two calls.
    balance += jax.lax.cond (result == 1.0, goodPrediction, lambda x: 0.0, [fOdds, fPoints, dPoints])
    balance += jax.lax.cond (result == 2.0, goodPrediction, lambda x: 0.0, [dOdds, dPoints, fPoints])
        
    return carry + balance, balance

def goodPrediction(args):
    """Return ``args[0] * 1000 * sigmoid(args[1] - args[2])``.

    Args:
        args: indexable of at least three numeric values. ``args[0]`` is the
            multiplier; ``args[1] - args[2]`` is the value fed to the
            sigmoid.

    Returns:
        The scaled sigmoid value as a JAX array.
    """
    # The sigmoid lives in jax.nn; a bare `nn` is not bound anywhere in this
    # file, while `jax` is bound by the module's `import jax.lax`.
    return args[0] * 1000 * jax.nn.sigmoid(args[1] - args[2])
                        
def MyFuncJax(coefs, data):
    """Return the summed per-row balance, offset by -1000 per row.

    Args:
        coefs: accepted but not used in this body (it is not forwarded to
            ``ParseRow``).
        data: sequence of rows; converted to a JAX array and scanned row by
            row with ``ParseRow``.

    Returns:
        ``-1000 * len(data)`` plus the final carry produced by the scan.
    """
    offset = float(len(data)*-1000)
    rows = jnp.asarray(data)

    # The scan returns (final_carry, stacked_outputs); only the carry is used.
    final_carry, _ = jax.lax.scan(ParseRow, 0, rows)
    return offset + final_carry
    
        
# Gradient-descent solver with MyFuncJax as its objective
# (fixed step size 0.001, at most 5000 iterations, no logging, jit enabled).
mini = jaxopt.GradientDescent(MyFuncJax, stepsize=0.001, maxiter=5000, verbose=0, jit=True)

# Argument 1 (dataTuple) is declared static for jax.jit; static arguments
# must be hashable, presumably why the rows are held as nested tuples.
@partial(jax.jit, static_argnums = (1,) )
def jitted_run(start, dataTuple):
    """Run the solver from ``start``, passing ``dataTuple`` along to ``mini.run``."""
    return mini.run(start, dataTuple)

# Kick off the optimisation at import time with the module-level inputs.
jitted_run(start, dataTuple)

Solution

  • Generally you can use the carry value to pass along extra data to the body function. So for example you could do something like this:

    def ParseRow (carry, row):
        # The carry is now a pair: the coefficients ride along unchanged
        # next to the running total.
        coefs, total = carry
        # ...
        # Hand coefs back in the new carry so the next step receives it too.
        return (coefs, total + balance), balance
    

    and then construct the scan call something like this:

    # Seed the carry with (coefs, 0) and unpack the final carry after the scan;
    # `result` holds the stacked per-row outputs.
    (coefs, total), result = jax.lax.scan(ParseRow, (coefs, 0), dataJnp)
    balance += total