Tags: machine-learning, keras, jax

Simultaneously going over different kinds of data with Keras training


In a regression task I'm given the following data:

  1. Input vectors with a known label. MSE loss should be used between the prediction and the label.
  2. Pairs of input vectors without a label, for which it is known that the model should give similar results. MSE loss should be used between the two predictions.
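Written out, the two objectives look roughly like this. The notation is an assumption, not from the question: a scalar regression output $f_\theta$, labelled pairs $(x_i, y_i)$, unlabelled pairs $(a_j, b_j)$, and a batch size $B$.

```latex
\mathcal{L}_{\text{sup}}(\theta) = \frac{1}{B}\sum_{i=1}^{B}\bigl(f_\theta(x_i) - y_i\bigr)^2,
\qquad
\mathcal{L}_{\text{pair}}(\theta) = \frac{1}{B}\sum_{j=1}^{B}\bigl(f_\theta(a_j) - f_\theta(b_j)\bigr)^2
```

In the pair loss, both predictions depend on $\theta$, so gradients flow through both branches of the model.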

What's the right way to fit a Keras model with these two kinds of data simultaneously?

Ideally, I'd like the training loop to iterate over the two kinds in an interleaved way: a supervised (1) batch, then a self-supervised (2) batch, then supervised again, and so on.
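The interleaving schedule described above can be sketched independently of any framework. This is a minimal sketch; `interleave_batches` is a hypothetical helper, not a Keras API. It tags each batch so the training loop can dispatch to the matching loss. It assumes both sources are non-empty and cycles the shorter one.

```python
from itertools import cycle
from typing import Any, Iterable, Iterator, Tuple


def interleave_batches(
    supervised_batches: Iterable[Any],
    paired_batches: Iterable[Any],
    num_steps: int,
) -> Iterator[Tuple[str, Any]]:
    """Alternate between the two data sources, starting with a supervised batch.

    Each source is cycled, so a shorter source is reused once it runs out.
    """
    sup_iter = cycle(supervised_batches)
    pair_iter = cycle(paired_batches)
    for step in range(num_steps):
        if step % 2 == 0:
            yield "supervised", next(sup_iter)  # kind (1): (x, y) batch
        else:
            yield "paired", next(pair_iter)  # kind (2): (a, b) batch
```

A training loop would then branch on the tag, for example `for kind, batch in interleave_batches(sup, pairs, 1000): ...`, and apply the supervised or the pair loss accordingly.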

If it matters, I'm using the JAX backend with Keras 3.2.1.


Solution

  • I eventually found a trick that solved it for my specific case without much customization.

    But if you do need to pass different kinds of data for training, I don't think there's an easy answer as of today.

    It should be possible, though, to write your own training loop and use any structure you want for the data and labels. In that case you might also want to follow the trainer pattern by implementing a custom version of keras.src.backend.jax.trainer.JAXTrainer.