
How can I add a constraint in TensorFlow so that as many weights as possible go to 0?


Say we have a simple neural network with 4 Dense layers, Lin -> L1 -> L2 -> Lout, and assume the weights of L2 form a 1x5 matrix whose 5 values are [a1, a2, a3, a4, a5]. When we train the model, we know there are many combinations of [a1, a2, a3, a4, a5] that fit the data, for example [1,2,3,4,5], [1,0,4,5,5], [0,0,15,0,0] and [0,0,0,5,0].

My question is how to add a constraint to the layer weights so that most of them are 0. For example, among the 4 L2 weight groups [1,2,3,4,5], [1,0,4,5,5], [0,0,15,0,0] and [0,0,0,5,0], the 3rd and 4th each have 4 zeros, and since 5 < 15 we treat the 4th as the preferred one among the 4 groups.
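The preference rule described above (most zeros first, then the smaller magnitude) can be expressed as a sort key. This sketch assumes "smaller" means the sum of absolute values. Note that this sum (the L1 norm) on its own already ranks [0, 0, 0, 5, 0] lowest, at 5 versus 15 for each of the other three groups.

```python
# Candidate L2 weight vectors from the question.
groups = [
    [1, 2, 3, 4, 5],
    [1, 0, 4, 5, 5],
    [0, 0, 15, 0, 0],
    [0, 0, 0, 5, 0],
]

def zero_count(w):
    """Number of entries that are exactly zero."""
    return sum(1 for v in w if v == 0)

def l1_norm(w):
    """Sum of absolute values (the quantity an L1 penalty adds to the loss)."""
    return sum(abs(v) for v in w)

# Prefer more zeros first; break ties with the smaller L1 norm (assumed criterion).
best = min(groups, key=lambda w: (-zero_count(w), l1_norm(w)))

for w in groups:
    print(w, "zeros:", zero_count(w), "L1:", l1_norm(w))
print("preferred:", best)
```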

We know TensorFlow Keras provides weight constraints: https://keras.io/api/layers/constraints/

But none of the built-in constraints covers this case. Any idea how to write such a constraint, or is there another way to do this?
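One way to write a custom constraint is to subclass tf.keras.constraints.Constraint and implement __call__, which Keras applies to the weight variable after each optimizer update. The sketch below uses a hypothetical ZeroSmallWeights class with an assumed threshold parameter that hard-thresholds small weights to exactly 0. It does not by itself push weights toward 0, so it would typically be paired with a penalty such as L1 regularization.

```python
import tensorflow as tf

class ZeroSmallWeights(tf.keras.constraints.Constraint):
    """Hypothetical constraint: set weights with |w| < threshold to exactly 0."""

    def __init__(self, threshold=1e-2):
        self.threshold = threshold

    def __call__(self, w):
        # 1.0 where |w| >= threshold, 0.0 elsewhere; multiply to zero out small weights.
        mask = tf.cast(tf.abs(w) >= self.threshold, w.dtype)
        return w * mask

    def get_config(self):
        # Lets the constraint be serialized together with the model.
        return {"threshold": self.threshold}

# Attached the same way as the built-in constraints (threshold value is an assumption).
layer = tf.keras.layers.Dense(5, kernel_constraint=ZeroSmallWeights(0.05))
```

A suitable threshold depends on the scale the weights settle at during training, so 0.05 here is only a placeholder.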

More specifically, we have many vectors that we want to classify, and we want a layer that learns which columns are important (we do not know the exact columns in advance). Just as word embeddings transform a word into a vector, here we want to transform a vector into an importance mask, then do further processing and drop the other columns. For example, with features [x1, x2, x3, x4, x5] and learned L2 weights [0,0,0,5,0], we can say the 4th column is important and transform the feature vector into [0, 0, 0, 5 * x4, 0].
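The transformation described above, where each column is multiplied by its own learned weight, can be sketched as a small custom Keras layer. FeatureWeighting is a hypothetical name, and initializing the weights to ones is an assumption; the layer only computes the element-wise product of the input with a trainable vector.

```python
import tensorflow as tf

class FeatureWeighting(tf.keras.layers.Layer):
    """Hypothetical layer: [x1, ..., xn] -> [w1*x1, ..., wn*xn] with trainable w."""

    def __init__(self, kernel_regularizer=None, **kwargs):
        super().__init__(**kwargs)
        self.kernel_regularizer = kernel_regularizer

    def build(self, input_shape):
        # One trainable weight per feature column.
        self.w = self.add_weight(
            name="w",
            shape=(input_shape[-1],),
            initializer="ones",
            regularizer=self.kernel_regularizer,
            trainable=True,
        )

    def call(self, inputs):
        # Element-wise product, broadcast over the batch dimension.
        return inputs * self.w

# Reproduce the question's example: weights [0, 0, 0, 5, 0].
layer = FeatureWeighting()
x = tf.constant([[1.0, 2.0, 3.0, 4.0, 5.0]])
layer(x)  # first call builds the layer
layer.w.assign([0.0, 0.0, 0.0, 5.0, 0.0])
print(layer(x).numpy())
```

Passing kernel_regularizer=tf.keras.regularizers.L1(0.01) (an assumed strength) would push most of the per-column weights toward 0 during training.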

Thanks in advance.


Solution

  • so that we can make sure most of them are 0

    If there is no strict requirement on the number of 0s (as you might have suggested with the single-column example), you are looking for Lasso regression (so-called L1 regularization), which, simply put, adds the sum of the absolute values of the weights to the loss and so penalizes the magnitude of each weight. A weight will only stay large if it is crucial for the inference.

    In TensorFlow 2.x this can be done via a kernel regularizer (the kernel_regularizer argument of a layer). This pushes the weights to be small, but it does not guarantee that they become exactly 0. Furthermore, it can strongly hurt performance if the penalty is too strong.

    As a side note, the problem you are probably trying to solve is related to machine learning interpretability/explainability. While your approach is interesting, it might be worth looking at methods and models built specifically for this purpose (some models can report feature significance, etc.).
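A minimal sketch of the L1 approach, assuming the question's four-layer chain with an input width of 10, 16 units in L1, 5 units in L2, 3 output classes and an L1 strength of 0.01 (all assumed values). tf.keras.regularizers.L1 is passed as the kernel_regularizer of L2, and Keras adds the penalty l1 * sum(|w|) for that kernel to the training loss.

```python
import tensorflow as tf

l1_strength = 0.01  # assumed value; tune for your data

# Lin -> L1 -> L2 -> Lout from the question, with an L1 penalty on L2's kernel.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(10,)),                      # assumed input width
    tf.keras.layers.Dense(16, activation="relu", name="Lin"),
    tf.keras.layers.Dense(16, activation="relu", name="L1"),
    tf.keras.layers.Dense(
        5, activation="relu", name="L2",
        kernel_regularizer=tf.keras.regularizers.L1(l1_strength),
    ),
    tf.keras.layers.Dense(3, activation="softmax", name="Lout"),  # assumed 3 classes
])

# The regularization term is added to this loss automatically during fit().
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
```

After training, model.get_layer("L2").kernel can be inspected; expect many small values rather than exact zeros.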