
CNTK: conditional execution


Is it possible to create a "conditional" network in CNTK and apply it only on one of the inputs depending on another input variable? See the following code:

import cntk as ct

a_in = ct.input_variable(shape=[16,16])
b_in = ct.input_variable(shape=[16,16])
flag = ct.input_variable(shape=[])

a_branch = ct.layers.Sequential([...])
b_branch = ct.layers.Sequential([...])

sel_branch = ct.element_select(flag, a_branch, b_branch)

out = sel_branch(a_in, b_in)

However, this doesn't work: sel_branch expects 3 arguments rather than the ones required by a_branch or b_branch (which makes sense, since I am using element_select incorrectly here).

Keep in mind that the objective is to avoid executing both branches.


Solution

  • No. At the moment there is no conditional execution in CNTK. In the general case flag is a vector/tensor in which some elements are 0 and others are 1, so the outputs of both branches are needed. There is an obvious optimization when all the elements have the same value, but it is not implemented. Even if it were implemented, sel_branch would still require 3 arguments, because its signature is a "compile-time" property, while that optimization can only be decided at runtime. Even in your case, where flag is a scalar, it might be 0 in one batch and 1 in the next, and the signature of sel_branch cannot change from batch to batch.
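
To make the point above concrete, here is a minimal plain-Python sketch. It is not CNTK code: element_select below is a hypothetical stand-in that only mimics element-wise selection (pick the first operand where the flag is non-zero, otherwise the second), and a_branch, b_branch, and the calls counter are invented for illustration. It shows that a select taking three operands needs both branch outputs before it can choose, and that skipping a branch is only possible if the decision is made outside the graph, assuming the flag has a single value for the whole batch.

```python
# Plain-Python illustration only; none of this is the CNTK implementation.
calls = {"a": 0, "b": 0}  # counts how often each hypothetical branch runs


def a_branch(x):
    # Hypothetical branch A: doubles every element.
    calls["a"] += 1
    return [2 * v for v in x]


def b_branch(x):
    # Hypothetical branch B: adds 100 to every element.
    calls["b"] += 1
    return [v + 100 for v in x]


def element_select(flag, if_true, if_false):
    # Stand-in for element-wise selection: takes if_true[i] where
    # flag[i] is non-zero, otherwise if_false[i].
    return [t if f != 0 else e for f, t, e in zip(flag, if_true, if_false)]


def select_in_graph(flag, a_in, b_in):
    # Both operands are computed before the selection happens,
    # so both branches always run, whatever the flag contains.
    return element_select(flag, a_branch(a_in), b_branch(b_in))


def select_on_host(flag_scalar, a_in, b_in):
    # Assumption: one flag value per batch, inspected by the driver code.
    # Only the chosen branch is evaluated.
    return a_branch(a_in) if flag_scalar != 0 else b_branch(b_in)


mixed = select_in_graph([1, 0, 1], [1, 2, 3], [4, 5, 6])
in_graph_calls = dict(calls)

calls["a"] = calls["b"] = 0
host = select_on_host(0, [1, 2, 3], [4, 5, 6])
host_calls = dict(calls)

print(mixed, in_graph_calls)
print(host, host_calls)
```

In the first call the flag mixes 0s and 1s, which is the general case described in the answer, so neither branch could have been skipped. The second call only works because the choice is made once per batch by ordinary Python code rather than by a node in the network.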