
Embedding Custom Functions into NN

I'm currently trying to figure out how to build a model with a few extra functions. I have a set of custom functions, and I want to embed them as layers in my model (a neural network). I'm using TF 2.0, but I'm struggling to do that. All I can find are answers about activation functions, which is not what I'm looking for.

A custom function returns something like a+b or implements some other algorithm (matrix multiplication, etc.). In other words, I have two consecutive layers and want to embed my custom functions between them, like so:

input -> dense layer 1 -> custom function 1 -> custom function 2 -> dense layer 2
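A minimal sketch of that pipeline with the tf.keras functional API, assuming the custom functions are stateless and built from TensorFlow ops. `custom_fn_1` and `custom_fn_2` are hypothetical placeholders (an offset and a fixed matrix multiplication), and each is wrapped in `tf.keras.layers.Lambda` so it can sit between the two `Dense` layers:

```python
import tensorflow as tf

# Hypothetical stand-ins for "custom function 1" and "custom function 2".
def custom_fn_1(x):
    # Element-wise example: add a constant offset.
    return x + 1.0

def custom_fn_2(x):
    # Matrix-multiplication example with a fixed (non-trainable) 8x8 matrix,
    # so the shape stays (batch, 8).
    m = tf.eye(8)
    return tf.matmul(x, m)

inputs = tf.keras.Input(shape=(4,))
h = tf.keras.layers.Dense(8, activation="relu")(inputs)  # dense layer 1
h = tf.keras.layers.Lambda(custom_fn_1)(h)              # custom function 1
h = tf.keras.layers.Lambda(custom_fn_2)(h)              # custom function 2
outputs = tf.keras.layers.Dense(2)(h)                   # dense layer 2
model = tf.keras.Model(inputs, outputs)

y = model(tf.ones((3, 4)))
print(y.shape)
```

Because the custom functions are made of differentiable TensorFlow ops, they become part of the same graph as the `Dense` layers; no activation-function trick is needed.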

One could say that the custom function is simply the activation function between one layer and the next. But what if my custom function takes two inputs? Or what if I have two functions that should each process my input before it is passed to the next function?
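In the functional API a function is not limited to one input: two branches can process the same input and then be merged by a function that takes both tensors. A minimal sketch, assuming a hypothetical two-input function `combine` that computes `(a + b) * a`; `Lambda` passes the list of tensors through to the function unchanged:

```python
import tensorflow as tf

# Hypothetical two-input custom function.
def combine(tensors):
    a, b = tensors            # unpack the list passed to the Lambda layer
    return (a + b) * a

inp = tf.keras.Input(shape=(4,))
branch_a = tf.keras.layers.Dense(8)(inp)   # first function/path applied to the input
branch_b = tf.keras.layers.Dense(8)(inp)   # second function/path applied to the input
merged = tf.keras.layers.Lambda(combine)([branch_a, branch_b])
out = tf.keras.layers.Dense(1)(merged)
model = tf.keras.Model(inp, out)

result = model(tf.ones((5, 4)))
```

For simple merges, built-in layers such as `tf.keras.layers.Add` or `tf.keras.layers.Concatenate` do the same job without a custom function.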

Another way to solve the problem: let's say I have my custom functions cm* and my layers l*. What I would do is build a separate model for each group of layers I want to put between two custom functions:

cm1 -> model(l1) -> cm2 -> model(l2,l3) -> cm3 -> cm4 -> model(l4) -> ....

But wouldn't it be stupid to build a model for each of those segments? And what about the loss? Backpropagation through residually connected layers seems different from having a lot of models and functions chained together. Or am I wrong?
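A sketch of how that chain can be a single model, under the assumption that every custom function is written with differentiable TensorFlow ops. `cm1` and `cm2` are hypothetical placeholders; a small `Sequential` sub-model plays the role of `model(l2, l3)` and is called like any other layer. One loss is computed for the whole chain, and `tf.GradientTape` returns a gradient for every trainable variable, including those before the custom functions:

```python
import tensorflow as tf

# Hypothetical stand-ins for the question's custom functions cm1 and cm2.
def cm1(x):
    return tf.square(x)

def cm2(x):
    return tf.tanh(x)

# "model(l2, l3)" as a sub-model; nested models are called like layers.
block = tf.keras.Sequential([tf.keras.layers.Dense(8), tf.keras.layers.Dense(8)])

inp = tf.keras.Input(shape=(4,))
x = tf.keras.layers.Lambda(cm1)(inp)   # cm1
x = tf.keras.layers.Dense(8)(x)        # l1 (no separate model needed)
x = tf.keras.layers.Lambda(cm2)(x)     # cm2
x = block(x)                           # model(l2, l3)
out = tf.keras.layers.Dense(1)(x)      # l4
model = tf.keras.Model(inp, out)

# Random data, only to demonstrate one training step.
features = tf.random.normal((16, 4))
targets = tf.random.normal((16, 1))

with tf.GradientTape() as tape:
    preds = model(features, training=True)
    loss = tf.reduce_mean(tf.square(targets - preds))  # one loss for the whole chain

# The tape differentiates through cm1/cm2 and the sub-model automatically.
grads = tape.gradient(loss, model.trainable_variables)
print([g is not None for g in grads])
```

Backpropagation here is just the chain rule through every op in the graph, the same mechanism that handles residual connections, so splitting the chain into many separately trained models is not necessary.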


  • I'm not sure about TF 2.0 specifically, but in Keras (available in TF 2.0 as `tf.keras`) you can build your own custom layers that receive multiple inputs by subclassing the Layer class. See for more details. The link doesn't explain how to pass multiple inputs to a layer, but all you have to do is call the layer with a list of inputs and unpack them inside the `call` method, something like this:

    import tensorflow as tf

    class MyCustomLayer(tf.keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # your code here

        def call(self, inputs):  # example call: MyCustomLayer()([tf.ones((2, 3)), tf.ones((2, 3))])
            x, y = inputs
            # your code here
            output = x + y  # placeholder
            return output
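The same pattern also works when the custom function needs its own trainable parameters, for example the matrix multiplication mentioned in the question, which a `Lambda` layer is not meant to hold. A minimal sketch with a hypothetical layer `WeightedSum` that computes `x @ W + y`; when a layer is called with a list, `build` receives one shape per input:

```python
import tensorflow as tf

class WeightedSum(tf.keras.layers.Layer):
    """Hypothetical two-input layer: returns x @ W + y with a trainable W."""

    def build(self, input_shape):
        x_shape, y_shape = input_shape  # one shape per element of the input list
        self.w = self.add_weight(
            name="w",
            shape=(x_shape[-1], y_shape[-1]),
            initializer="glorot_uniform",
            trainable=True,
        )

    def call(self, inputs):
        x, y = inputs
        return tf.matmul(x, self.w) + y

a = tf.keras.Input(shape=(4,))
b = tf.keras.Input(shape=(3,))
out = WeightedSum()([a, b])
model = tf.keras.Model([a, b], out)

# With x = 0, the output reduces to y.
result = model([tf.zeros((2, 4)), tf.ones((2, 3))])
```

Since `W` is created with `add_weight`, it is tracked by the model and updated by the optimizer together with the surrounding layers.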