
Pass a custom scaling operation in Python


I am following an example from https://github.com/google/lightweight_mmm, but instead of using the default setting for the scalers, which divides by the mean:

media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)

I need to use this lambda function instead:

lambda x: jnp.mean(x[x > 0])
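For readers new to the syntax: x[x > 0] is boolean masking, so this lambda averages only the strictly positive entries, while plain jnp.mean averages every entry, zeros included. A small sketch using NumPy instead of jax.numpy, assuming the two behave the same for these calls; the sample array is made up:

```python
import numpy as np

# Made-up example column that contains zeros
x = np.array([0.0, 2.0, 0.0, 4.0])

default_divisor = np.mean(x)                 # averages all four entries
positive_mean = lambda x: np.mean(x[x > 0])  # keeps only 2.0 and 4.0, then averages
custom_divisor = positive_mean(x)

print(default_divisor)  # 1.5
print(custom_divisor)   # 3.0
```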

How can this be done? I tried a couple of things, but since I am a complete beginner, I feel lost.

So far I have tried:

lambda x: jnp.mean(x[x > 0])
media_scaler = preprocessing.CustomScaler(divide_operation=x)

and

lambda x: jnp.mean(x[x > 0])
media_scaler = preprocessing.CustomScaler(divide_operation=lambda)

Neither of these works.
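Why both attempts fail, as a stdlib-only sketch: a lambda on a line by itself builds a function and immediately discards it, because nothing is bound to it; x is only the lambda's parameter name, so it does not exist outside the lambda body; and the bare keyword lambda is not a value, so divide_operation=lambda is a syntax error. To run without JAX or lightweight_mmm, the snippet swaps jnp.mean for a sum/len average and uses dict as a stand-in for the CustomScaler constructor; run is a hypothetical helper for this illustration.

```python
def run(source):
    """Compile and execute a code string; return the exception name, or None if it runs."""
    try:
        exec(compile(source, "<attempt>", "exec"), {})
    except (SyntaxError, NameError) as err:
        return type(err).__name__
    return None


# Attempt 1: the lambda line builds a function and discards it,
# and x is only the lambda's parameter, so it is undefined afterwards.
attempt_1 = "lambda x: sum(x) / len(x)\nscaler_arg = dict(divide_operation=x)\n"

# Attempt 2: the bare keyword lambda is not a complete expression.
attempt_2 = "scaler_arg = dict(divide_operation=lambda)\n"

# Working pattern: bind the lambda to a name, then pass the name.
working = "div = lambda x: sum(x) / len(x)\nscaler_arg = dict(divide_operation=div)\n"

print(run(attempt_1))  # NameError
print(run(attempt_2))  # SyntaxError
print(run(working))    # None
```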


Solution

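  • This should do it. A lambda expression evaluates to a function object, so bind it to a name and pass that name (passing the expression directly, divide_operation=lambda x: jnp.mean(x[x > 0]), works the same way):

    import jax.numpy as jnp
    from lightweight_mmm import preprocessing
    div = lambda x: jnp.mean(x[x > 0])  # mean of the strictly positive entries only
    media_scaler = preprocessing.CustomScaler(divide_operation=div)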
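To see why this works, here is a minimal sketch with a hypothetical ToyScaler class (not lightweight_mmm's implementation), assuming the scaler calls divide_operation on each column to get that column's divisor. It uses NumPy in place of jax.numpy and shows that a named lambda, an inline lambda, and a def function are interchangeable as the argument:

```python
import numpy as np


class ToyScaler:
    """Hypothetical stand-in for a scaler that accepts any callable.

    Not lightweight_mmm's code; it only illustrates passing a function
    object as the divide_operation argument.
    """

    def __init__(self, divide_operation):
        self.divide_operation = divide_operation  # stored now, called later

    def fit_transform(self, data):
        # Call the function on each column (axis 0) to get one divisor per column.
        self.divide_by = np.apply_along_axis(self.divide_operation, 0, data)
        return data / self.divide_by


def positive_mean(x):
    """def form of lambda x: np.mean(x[x > 0])."""
    return np.mean(x[x > 0])


data = np.array([[0.0, 1.0],
                 [2.0, 3.0],
                 [4.0, 0.0]])

# Three equivalent ways to pass the same operation:
div = lambda x: np.mean(x[x > 0])
scaler = ToyScaler(divide_operation=div)
scaled_a = scaler.fit_transform(data)
scaled_b = ToyScaler(divide_operation=lambda x: np.mean(x[x > 0])).fit_transform(data)
scaled_c = ToyScaler(divide_operation=positive_mean).fit_transform(data)

print(scaler.divide_by)  # [3. 2.]
```

All three forms hand the scaler the same kind of function object. PEP 8 recommends a def statement over assigning a lambda to a name, mainly because the function then carries a meaningful name in tracebacks.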