Is it possible to create a "conditional" network in CNTK and apply it to only one of the inputs, depending on the value of another input variable? See the following code:
```python
import cntk as ct

a_in = ct.input_variable(shape=[16,16])
b_in = ct.input_variable(shape=[16,16])
flag = ct.input_variable(shape=[])   # scalar selector
# Layer lists are elided in the original question.
a_branch = ct.layers.Sequential([...])
b_branch = ct.layers.Sequential([...])
sel_branch = ct.element_select(flag, a_branch, b_branch)
out = sel_branch(a_in, b_in)
```
However, this doesn't work, since `sel_branch` expects 3 arguments instead of the ones required by `a_branch` or `b_branch` (which is entirely expected, since I am using `element_select` the wrong way here).
Keep in mind that the objective is to avoid executing both branches.
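For reference, `element_select` selects between values rather than between networks: each branch has to be applied to its input first (in CNTK, presumably something like `ct.element_select(flag, a_branch(a_in), b_branch(b_in))`), which means both branches are evaluated. A minimal sketch of those semantics, assuming NumPy's `np.where` as a stand-in for `element_select` and hypothetical one-line functions in place of the elided `Sequential` models:

```python
import numpy as np

# Hypothetical stand-ins for the elided a_branch / b_branch Sequential models.
def a_branch(x):
    return 2.0 * x

def b_branch(x):
    return x - 1.0

def select_branch(flag, a_in, b_in):
    # Each branch is applied to its own input, so both always run...
    a_out = a_branch(a_in)
    b_out = b_branch(b_in)
    # ...and the flag then picks, element by element, which result to keep.
    return np.where(flag != 0, a_out, b_out)

a_in = np.ones((16, 16))
b_in = np.full((16, 16), 5.0)
out = select_branch(1.0, a_in, b_in)  # every element comes from a_branch
```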
The answer is no: at the moment there is no conditional execution in CNTK. In the general case `flag` is a vector/tensor, and some of its elements may be 0 while others are 1. There is an obvious optimization when all the elements have the same value, but it is not implemented. And even if it were, `sel_branch` would still require 3 arguments, because its signature is a "compile-time" property, whereas the optimization can only be decided at runtime. Even in your case, where `flag` is a scalar, it might be 0 in one batch and 1 in the next, and the signature of `sel_branch` cannot change from batch to batch.
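The runtime shortcut described above can also be sketched in NumPy. This is a hypothetical illustration, not something CNTK provides: the function always takes `flag`, `a_in` and `b_in`, and only once the flag values are known does it decide whether one branch can be skipped. The call counters exist only to show which branches actually ran.

```python
import numpy as np

calls = {"a": 0, "b": 0}  # counts how often each hypothetical branch runs

def a_branch(x):
    calls["a"] += 1
    return 2.0 * x

def b_branch(x):
    calls["b"] += 1
    return x - 1.0

def select_with_shortcut(flag, a_in, b_in):
    # The signature is fixed: all three inputs are always required.
    mask = np.asarray(flag) != 0
    if mask.all():
        return a_branch(a_in)   # every element selects a: skip b
    if not mask.any():
        return b_branch(b_in)   # every element selects b: skip a
    # Mixed flag: both branches are needed for the element-wise merge.
    return np.where(mask, a_branch(a_in), b_branch(b_in))

a_in = np.ones((2, 2))
b_in = np.full((2, 2), 5.0)
r1 = select_with_shortcut(1.0, a_in, b_in)                        # runs only a_branch
r2 = select_with_shortcut(np.array([[1, 0], [0, 1]]), a_in, b_in)  # runs both
print(calls)  # {'a': 2, 'b': 1}
```

The mixed flag in the second call is the general case: whether a branch can be skipped depends on the flag values of each batch, so the decision cannot be moved into the signature.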