
How to write update in theano function


I am new to Theano, so maybe this is a simple question. Suppose I have a function

f = theano.function(
          inputs=[x],
          outputs=[y],
          updates=update)

where y depends on a shared variable w that I want to update with the rule

w = w + tr_rate * (pos_associations-neg_associations)

I can write

wparameters = [w]
update = [(wparam, 
           wparam + tr_rate * (pos_associations-neg_associations)) for wparam in wparameters]

and every call to f will then update w with this rule.

But if y also depends on another variable, say z, that I want to update with a different rule, for example

z = z + tr_rate*(x - vis)

how do I combine the two rules in the same function?


Solution

  • I found the answer myself and am posting it in case it helps others. The updates argument is just a list of (shared variable, new expression) pairs, so you can start with an empty list update and use its .append method to add one pair per rule.

    So, instead of

    wparameters = [w]
    update = [(wparam, 
               wparam + tr_rate * (pos_associations-neg_associations)) for wparam in wparameters]
    

    you can append a new rule and write:

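    wparameters = [w]
    zparameters = [z]

    update = []
    # zip stops at the shorter list, so wparameters and zparameters
    # must have the same length for every rule to be added
    for wparam, zparam in zip(wparameters, zparameters):
        update.append((wparam, wparam + tr_rate*(pos_associations - neg_associations)))
        update.append((zparam, zparam + tr_rate*(x - vis)))

    # a single updates list now carries both rules
    f = theano.function(inputs=[x], outputs=[y], updates=update)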
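The key point of the answer is that the updates argument is an ordinary Python list of (shared variable, new value) pairs, so two rules are combined simply by putting both pairs in the same list (the separate lists can even be concatenated with +). To keep the sketch runnable without Theano installed, it uses hypothetical pure-Python stand-ins, Shared and make_function, plus made-up scalar values; they only mimic the behaviour that matters here and are not Theano's API. They also mirror how Theano computes the outputs and all update expressions from the values before the call, and only then assigns the new values.

```python
# Hypothetical stand-ins for Theano objects, used only to illustrate how one
# updates list can carry several rules. This is NOT Theano's API.


class Shared:
    """Minimal stand-in for a Theano shared variable: it just stores a value."""

    def __init__(self, value):
        self.value = value


def make_function(compute_output, updates):
    """Minimal stand-in for theano.function(..., updates=updates).

    updates is a list of (shared_var, rule) pairs, where rule(x) returns the
    new value for shared_var. The output and every new value are computed
    from the current (old) values first; only then are the shared variables
    overwritten, so the order of the pairs in the list does not matter.
    """
    def f(x):
        output = compute_output(x)
        new_values = [(var, rule(x)) for var, rule in updates]
        for var, value in new_values:
            var.value = value
        return output

    return f


# Made-up scalar values standing in for the quantities in the question.
tr_rate = 0.1
pos_associations, neg_associations = 3.0, 1.0
vis = 0.5

w = Shared(1.0)
z = Shared(0.0)

# One list of pairs per rule, built like the comprehension in the question...
w_updates = [(p, lambda x, p=p: p.value + tr_rate * (pos_associations - neg_associations))
             for p in [w]]
z_updates = [(p, lambda x, p=p: p.value + tr_rate * (x - vis))
             for p in [z]]

# ...and combined into a single updates list by concatenation.
update = w_updates + z_updates

# y depends on both w and z.
f = make_function(lambda x: w.value * x + z.value, update)

y = f(2.0)
print(y, w.value, z.value)
```

Calling f(2.0) returns 2.0, computed from the old w and z; afterwards w is about 1.2 and z about 0.15. In real Theano the second element of each pair would be a symbolic expression such as w + tr_rate * (pos_associations - neg_associations) rather than a Python lambda.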