
Keras Loss function using outputs of batch


I'm trying to learn a joint embedded representation of images and text with a two-branch neural network in Keras. This is what my model looks like: [image: model structure]

These are the current in- and outputs of my training model:

model = Model([txt_input,img_input], [encoded_txt, encoded_img])

I have to use a bidirectional ranking loss, which means that the representations of a corresponding text and image should be closer to each other than to any other image/text, by a margin m. This is the entire loss function, where:

  • s : a similarity function
  • D : the training set
  • Yi+ : the set of corresponding (positive) text descriptions given image xi (only one positive in my experiment)
  • Yi- : the set of non-corresponding (negative) descriptions given image xi
  • Xi+ : the set of corresponding (positive) images given text description yi (only one positive in my experiment)
  • Xi- : the set of non-corresponding (negative) images given text description yi

[image: loss formula]
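The formula image did not survive. Assuming it is the standard bidirectional hinge ranking loss implied by the definitions above (the paper's version may add weighting between the two directions or further structure-preserving terms, which are not shown here), it can be written as follows, with i running over the image/text pairs in D:

```latex
L = \sum_{i} \sum_{y^+ \in Y_i^+} \sum_{y^- \in Y_i^-}
      \max\bigl(0,\; m + s(x_i, y^-) - s(x_i, y^+)\bigr)
  + \sum_{i} \sum_{x^+ \in X_i^+} \sum_{x^- \in X_i^-}
      \max\bigl(0,\; m + s(x^-, y_i) - s(x^+, y_i)\bigr)
```

With only one positive per image and per text, the sums over $y^+$ and $x^+$ each reduce to a single term.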

The problem is that to compute this loss for an example, I need not only the outputs for the current image and its corresponding text, but also their similarities to the representations of other images/texts.
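For illustration, here is a minimal NumPy sketch of that computation for one batch, assuming row i of the image and text embeddings form the (single) positive pair and every other row in the batch is a negative. Cosine similarity as s, the margin value, and the function name `bidirectional_ranking_loss` are assumptions, not taken from the question or the paper.

```python
import numpy as np


def bidirectional_ranking_loss(img_emb, txt_emb, margin=0.1):
    """Bidirectional hinge ranking loss using in-batch negatives (sketch).

    Assumes img_emb[i] and txt_emb[i] are a matching pair and all other
    rows of the batch are negatives. s is cosine similarity (assumption).
    """
    # L2-normalise rows so the dot product equals cosine similarity
    img = img_emb / np.linalg.norm(img_emb, axis=1, keepdims=True)
    txt = txt_emb / np.linalg.norm(txt_emb, axis=1, keepdims=True)
    sim = img @ txt.T              # sim[i, j] = s(x_i, y_j)
    pos = np.diag(sim)             # s(x_i, y_i) for the positive pairs
    # image -> text: text y_j is a negative for image x_i
    cost_txt = np.maximum(0.0, margin + sim - pos[:, None])
    # text -> image: image x_i is a negative for text y_j
    cost_img = np.maximum(0.0, margin + sim - pos[None, :])
    # the diagonal holds the positive pairs, so exclude it
    mask = 1.0 - np.eye(sim.shape[0])
    return float(np.sum((cost_txt + cost_img) * mask))
```

The whole batch is needed only to build the similarity matrix `sim`; once it exists, both directions of the loss are simple row/column comparisons against its diagonal.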

Concretely, my question is: Is there a way to include the outputs of the entire batch, or at least the previous n samples, when calculating the loss?

The only way I can see to do this is to create a stateful loss function that keeps the representations of the last n samples and uses them to compute the similarities. I don't think that is a good solution, and I was wondering whether there is a more elegant way to implement this. I'm also looking into other frameworks such as PyTorch to check whether they support batch-wise losses. Any help would be greatly appreciated.

Thank you!

PS: I'm actually trying to reproduce the experiment of this paper:

L. Wang, Y. Li, and S. Lazebnik, “Learning deep structure-preserving image-text embeddings,” in Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 5005–5013, 2016.

The image has been extracted from this paper too.


Solution

  • Concretely, my question is: Is there a way to include the outputs of the entire batch, or at least the previous n samples, when calculating the loss?

    I think the question rests on a misconception. If you set, say, batch size = 8 when training, your loss function receives the whole batch (tensors whose first dimension is 8), and the loss is calculated over that batch.

    Check the Keras losses implementation:

    class LossFunctionWrapper(Loss):
        """Wraps a loss function in the `Loss` class.
        # Arguments
            fn: The loss function to wrap, with signature `fn(y_true, y_pred,
                **kwargs)`.
            reduction: (Optional) Type of loss reduction to apply to loss.
                Default value is `SUM_OVER_BATCH_SIZE`.
            name: (Optional) name for the loss.
            **kwargs: The keyword arguments that are passed on to `fn`.
        """
    

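    The default reduction is SUM_OVER_BATCH_SIZE, i.e. the per-sample losses are summed and divided by the batch size.

    Because the wrapped function receives the predictions for the entire batch, you can calculate a loss that compares samples within the batch. Note that when a model has several outputs, each output gets its own loss call that only sees that output, so to compare the text and image embeddings in one loss you need to combine them into a single output first (for example by concatenation).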
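    A minimal sketch of such a batch-wise loss, assuming TensorFlow 2.x, cosine similarity as s, and a model whose single output is the concatenation [encoded_txt, encoded_img] along the last axis. The factory name `make_ranking_loss` and this output layout are illustrative assumptions, not part of the Keras API. Every non-matching pair in the batch acts as a negative, so no state across batches is needed.

```python
import tensorflow as tf


def make_ranking_loss(emb_dim, margin=0.1):
    """Return a Keras-compatible loss fn(y_true, y_pred) (sketch).

    Assumption: y_pred = concat([encoded_txt, encoded_img], axis=-1),
    each part of width emb_dim. y_true is ignored (dummy target).
    """
    def loss(y_true, y_pred):
        # split the concatenated output back into the two branches
        txt = tf.math.l2_normalize(y_pred[:, :emb_dim], axis=1)
        img = tf.math.l2_normalize(y_pred[:, emb_dim:], axis=1)
        # sim[i, j] = cosine similarity between image i and text j
        sim = tf.matmul(img, txt, transpose_b=True)
        pos = tf.linalg.diag_part(sim)  # matching pairs on the diagonal
        # image -> text and text -> image hinge terms
        cost_txt = tf.nn.relu(margin + sim - pos[:, None])
        cost_img = tf.nn.relu(margin + sim - pos[None, :])
        # drop the diagonal, which holds the positives themselves
        n = tf.shape(sim)[0]
        mask = 1.0 - tf.eye(n, dtype=sim.dtype)
        return tf.reduce_sum((cost_txt + cost_img) * mask)

    return loss
```

    One way to wire it up is to build the training model on `tf.keras.layers.Concatenate()([encoded_txt, encoded_img])` instead of the two separate outputs, compile with `loss=make_ranking_loss(emb_dim)`, and pass a dummy target (e.g. zeros of the same shape) to `fit`, since `y_true` is not used. Each sample gets batch size minus one negatives, so the batch size directly controls how many negatives the loss sees.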

    Alternatively, you can use a triplet-loss formulation: generate the positive and negative samples explicitly, with a flag indicating which is which, so that the calculation inside the loss function stays simple.
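    One possible reading of the triplet suggestion, as a minimal NumPy sketch: instead of using all in-batch negatives, sample one explicit negative per anchor and apply the hinge max(0, m + s(a, n) - s(a, p)). The helper names `sample_triplets` and `triplet_loss`, the random sampling scheme, and plain dot-product similarity are assumptions for illustration; the flag mechanism mentioned above is not reproduced.

```python
import numpy as np


def sample_triplets(batch_size, rng):
    """Hypothetical helper: one random negative index per anchor.

    Adds a random non-zero offset (mod batch_size) so neg[i] != i.
    Requires batch_size >= 2.
    """
    offsets = rng.integers(1, batch_size, size=batch_size)
    return (np.arange(batch_size) + offsets) % batch_size


def triplet_loss(anchor, positive, negative, margin=0.1):
    """Triplet hinge with dot-product similarity (assumption)."""
    s_pos = np.sum(anchor * positive, axis=1)  # s(a, p) per row
    s_neg = np.sum(anchor * negative, axis=1)  # s(a, n) per row
    return float(np.sum(np.maximum(0.0, margin + s_neg - s_pos)))
```

    For the bidirectional case, the same hinge can be applied in both directions, e.g. `neg = sample_triplets(len(img), rng)` followed by `triplet_loss(img, txt, txt[neg]) + triplet_loss(txt, img, img[neg])`.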

    Finally, here is a TensorFlow implementation of the paper that may help: https://github.com/lwwang/Two_branch_network