Tags: tensorflow, keras, mixture-model, tensorflow-probability

Keras Model using Tensorflow Distribution for loss fails with batch size > 1


I'm trying to use a distribution from tensorflow_probability to define a custom loss function in Keras. More specifically, I'm trying to build a Mixture Density Network.
My model works on a toy dataset when batch_size = 1 (it learns to predict the correct mixture distribution for y using x). But it "fails" when batch_size > 1 (it predicts the same distribution for all y, ignoring x). This makes me think my problem has to do with batch_shape vs. sample_shape.
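
To illustrate the batch_shape vs. sample_shape distinction I have in mind, here is a minimal sketch (assuming TF 2.x eager execution; the parameter values are placeholders):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# One mixture distribution per example in a batch of 3
dist = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=tf.fill([3, 2], 0.5)),
    components_distribution=tfd.Normal(loc=tf.zeros([3, 2]), scale=tf.ones([3, 2])))

print(dist.batch_shape)                    # (3,)   - one distribution per example
print(dist.event_shape)                    # ()     - each draw is a scalar
print(dist.sample(5).shape)                # (5, 3) - sample_shape is prepended
print(dist.log_prob(tf.zeros([3])).shape)  # (3,)   - one log-prob per example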

To reproduce:

import random

import keras
from keras import backend as K
from keras.layers import Dense, Activation, LSTM, Input, Concatenate, Reshape, concatenate, Flatten, Lambda
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping
from keras.models import Sequential, Model
import tensorflow
import tensorflow_probability as tfp
tfd = tfp.distributions

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# generate toy dataset
np.random.seed(12902)  # seed NumPy, since the data is drawn with np.random
n_obs = 20000
x = np.random.uniform(size=(n_obs, 4))
df = pd.DataFrame(x, columns = ['x_{0}'.format(i) for i in np.arange(4)])

# 2 latent classes, with noisy assignment based on x_0, x_1, (x_2 and x_3 are noise)
df['latent_class'] = 0
df.loc[df.x_0 + df.x_1 + np.random.normal(scale=.05, size=n_obs) > 1, 'latent_class'] = 1
df.latent_class.value_counts()

# The latent class determines which mixture distribution we draw from
d0 = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(
      loc=[-1., 1], scale=[0.1, 0.5])) 
d0_samples = d0.sample(sample_shape=(df.latent_class == 0).sum()).numpy()

d1 = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.5, 0.5]),
    components_distribution=tfd.Normal(
      loc=[-2., 2], scale=[0.2, 0.6]))

d1_samples = d1.sample(sample_shape=(df.latent_class == 1).sum()).numpy()

df.loc[df.latent_class == 0, 'y'] = d0_samples
df.loc[df.latent_class == 1, 'y'] = d1_samples

fig, ax = plt.subplots()
bins = np.linspace(-4, 5, 9*4 + 1)
df.y[df.latent_class == 0].hist(ax=ax, bins=bins, label='Class 0', alpha=.4, density=True)
df.y[df.latent_class == 1].hist(ax=ax, bins=bins, label='Class 1', alpha=.4, density=True)
ax.legend();

# mixture density network
N_COMPONENTS = 2  # number of components in the mixture
input_feature_space = 4

flat_input = Input(shape=(input_feature_space,), name='inputs')
x = Dense(6, activation='relu', 
             kernel_initializer='glorot_uniform',
             bias_initializer='ones')(flat_input)
x = Dense(6, activation='relu',
             kernel_initializer='glorot_uniform',
             bias_initializer='ones')(x)
# 3 params per component: weight, scale, loc (sliced in that order below)
output = Dense(N_COMPONENTS*3,
             kernel_initializer='glorot_uniform',
             bias_initializer='ones')(x)

model = Model(inputs=[flat_input],
              outputs=[output])

I suspect the problem is in the next 3 functions:


def get_mixture_coef(output, num_components):
    """
    Extract mixture params from output
    """
    out_pi = output[:, :num_components]
    out_sigma = output[:, num_components:2*num_components]
    out_mu = output[:, 2*num_components:]
    # use softmax to normalize pi into prob distribution
    max_pi = K.max(out_pi, axis=1, keepdims=True)
    out_pi = out_pi - max_pi
    out_pi = K.exp(out_pi)
    normalize_pi = 1 / K.sum(out_pi, axis=1, keepdims=True)
    out_pi = normalize_pi * out_pi
    # use exp to ensure sigma is pos
    out_sigma = K.exp(out_sigma)
    return out_pi, out_sigma, out_mu

def get_lossfunc(out_pi, out_sigma, out_mu, y):
    d0 = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(
            probs=out_pi),
        components_distribution=tfd.Normal(
          loc=out_mu, scale=out_sigma,
        ),
    )
    # I suspect the problem is here
    return -1 * d0.log_prob(y)

def mdn_loss(num_components):
    def loss(y_true, y_pred):
        out_pi, out_sigma, out_mu = get_mixture_coef(y_pred, num_components)
        return get_lossfunc(out_pi, out_sigma, out_mu, y_true)
    return loss

opt = Adam(lr=.001)
model.compile(
    optimizer=opt,
    loss = mdn_loss(N_COMPONENTS),
)

es = EarlyStopping(monitor='val_loss',
                  min_delta=1e-5,
                  patience=5,
                  verbose=1, mode='auto')

validation = .15
validate_idx = np.random.choice(df.index.values, 
                                size=int(validation * df.shape[0]), 
                                replace=False)
train_idx = [i for i in df.index.values if i not in validate_idx]

x_cols = ['x_0', 'x_1', 'x_2', 'x_3']

model.fit(x=df.loc[train_idx, x_cols].values,
          y=df.loc[train_idx, 'y'].values[:, np.newaxis],
          validation_data=(
            df.loc[validate_idx, x_cols].values,
            df.loc[validate_idx, 'y'].values[:, np.newaxis]), 
          # model works when batch_size = 1
          # model fails when batch_size > 1
          epochs=2, batch_size=1, verbose=1, callbacks=[es])

def sample(output, n_samples, num_components):
    """Sample from a mixture distribution parameterized by
    model output."""
    pi, sigma, mu = get_mixture_coef(output, num_components)    
    d0 = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=pi),
        components_distribution=tfd.Normal(loc=mu, scale=sigma))
    return d0.sample(sample_shape=n_samples).numpy()

yhat = model.predict(df.loc[train_idx, x_cols].values)

out_pi, out_sigma, out_mu = get_mixture_coef(yhat, 2)

latent_1_samples = sample(yhat[:1], n_samples=1000, num_components=2)
latent_1_samples = pd.DataFrame({'latent_1_samples': latent_1_samples.ravel()})

fig, ax = plt.subplots()
bins = np.linspace(-4, 5, 9*4 + 1)
latent_1_samples.latent_1_samples.hist(ax=ax, bins=bins, label='Class 1: yHat', alpha=.4, density=True)
df.y[df.latent_class == 0].hist(ax=ax, bins=bins, label='Class 0: True', density=True, histtype='step')
df.y[df.latent_class == 1].hist(ax=ax, bins=bins, label='Class 1: True', density=True, histtype='step')
ax.legend();

Thanks in advance!

Update

I found two ways to solve the problem, guided by this answer. Both solutions point to the fact that Keras is awkwardly broadcasting y to match y_pred:

def get_lossfunc(out_pi, out_sigma, out_mu, y):
    d0 = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(
            probs=out_pi),
        components_distribution=tfd.Normal(
          loc=out_mu, scale=out_sigma,
        ),
    )
    # this also works: 
    # return -1 * d0.log_prob(tensorflow.transpose(y))
    return -1 * d0.log_prob(y[:, 0])
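
To make the broadcasting problem concrete, here is a minimal shape check (a sketch, assuming TF 2.x eager execution; the parameter tensors are stand-ins for the network output):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

batch = 4
# Stand-in per-example mixture parameters, shaped like the model's output
dist = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=tf.fill([batch, 2], 0.5)),
    components_distribution=tfd.Normal(loc=tf.zeros([batch, 2]), scale=tf.ones([batch, 2])))

y = tf.zeros([batch, 1])  # the shape Keras hands the loss as y_true

print(dist.log_prob(y).shape)                # (4, 4) - every target paired with every distribution
print(dist.log_prob(y[:, 0]).shape)          # (4,)   - one log-prob per example
print(dist.log_prob(tf.transpose(y)).shape)  # (1, 4) - also one log-prob per example

With batch_size = 1 the broadcast result is a 1x1 matrix that happens to equal the correct per-example loss, which would explain why training only worked at batch size 1.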

Solution

  • Restating the workaround here in the answer section, even though Dan already spelled it out in the question, for the benefit of the community.

    The problem of predicting the same distribution for all y (ignoring x) can be resolved in two ways.

    Code for Solution 1 (transpose y before computing log_prob):

    def get_lossfunc(out_pi, out_sigma, out_mu, y):
        d0 = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(
                probs=out_pi),
            components_distribution=tfd.Normal(
              loc=out_mu, scale=out_sigma,
            ),
        )
        return -1 * d0.log_prob(tensorflow.transpose(y))
    

    Code for Solution 2 (index out the single column of y):

    def get_lossfunc(out_pi, out_sigma, out_mu, y):
        d0 = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(
                probs=out_pi),
            components_distribution=tfd.Normal(
              loc=out_mu, scale=out_sigma,
            ),
        )
        return -1 * d0.log_prob(y[:, 0])
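
    Both workarounds do the same thing: they reshape y_true from (batch_size, 1) into a shape that broadcasts element-wise against the mixture distribution's batch_shape of (batch_size,), so log_prob returns one log-probability per example instead of an all-pairs matrix.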
    

    Hope this helps. Happy Learning!