Tags: tensorflow, machine-learning, neural-network, tensorflow2.0, loss-function

Passing MetaData to Custom Loss Function


I'd like to create a custom loss function that depends on metadata. In the simplest form, I'd like to multiply the loss by a per-batch weight (determined by the metadata).

For simplicity, consider passing the desired weight directly. Here are two attempts at loss functions:

def three_arg_loss(loss_func):
    """ a loss function that takes 3 args"""
    def _loss(target,output,weight):
        return weight*loss_func(target,output)
    return _loss

def target_list_loss(loss_func):
    """ a loss function that expects the target arg to be [target,weight]"""
    def _loss(target,output):
        weight=target[1]
        target=target[0]
        return weight*loss_func(target,output)
    return _loss

When I tried to train, I got the following:

  • three_arg_loss: TypeError: tf___loss() missing 1 required positional argument: 'weight'

But of course, I triple-checked, and I was indeed passing three args.

  • target_list_loss: ValueError: Shapes (None, None, None) and (None, None, None, 4) are incompatible

And again, after triple-checking, I was indeed passing [target,weight] as the target argument. Worried that I might have messed up the order of the arguments to the loss function, I flipped them just to be sure and got ValueError: Shapes (None, None, 4) and (None, None, None, None) are incompatible

Thoughts? What's the correct/best approach for a loss function that depends on additional data (in my case, geographic location)?

As requested, here is a complete (but silly) example showing the errors:

import tensorflow as tf
from tensorflow.keras import layers

BATCH_SIZE=2
SIZE=3
STEPS=8
EPOCHS=3
NB_CLASSES=4


def gen_inpt(ch_in):
    return tf.random.uniform((BATCH_SIZE,SIZE,SIZE,ch_in))

def gen_targ(nb_classes):
    t=tf.random.uniform((BATCH_SIZE,SIZE,SIZE),maxval=nb_classes,dtype=tf.int32)
    return tf.keras.utils.to_categorical(t,num_classes=nb_classes)

def gen(ch_in,ch_out):
    """ yields (input,target) batches """
    return ( ( gen_inpt(ch_in), gen_targ(ch_out) ) for b in range(BATCH_SIZE*STEPS*EPOCHS) )

def gen_targ_list(ch_in,ch_out):
    """ yields batches whose target is [target,weight] """
    return ( ( gen_inpt(ch_in), [gen_targ(ch_out), tf.fill([1],2222)] ) for b in range(BATCH_SIZE*STEPS*EPOCHS) )

def gen_3args(ch_in,ch_out):
    """ yields (input,target,weight) batches """
    return ( ( gen_inpt(ch_in), gen_targ(ch_out), tf.fill([1],10000.0) ) for b in range(BATCH_SIZE*STEPS*EPOCHS) )


class Toy(tf.keras.Model):
    
    def __init__(self,nb_classes):
        super(Toy, self).__init__()
        self.l1=layers.Conv2D(32,3,padding='same')
        self.l2=layers.Conv2D(nb_classes,3,padding='same')
        
    def call(self,x):
        x=self.l1(x)
        x=self.l2(x)
        return x

def test_loss(loss_func):
    def _loss(target,output):
        return loss_func(target,output)
    return _loss


def target_list_loss(loss_func):
    def _loss(target,output):
        weight=target[1]
        target=target[0]
        return weight*loss_func(target,output)
    return _loss


def three_arg_loss(loss_func):
    def _loss(target,output,weight):
        return weight*loss_func(target,output)
    return _loss


loss_func=tf.keras.losses.CategoricalCrossentropy()

loss_test=test_loss(loss_func)
loss_targ_list=target_list_loss(loss_func)
loss_3arg=three_arg_loss(loss_func)

def test_train(loss,gen):
    try: 
        model=Toy(NB_CLASSES)    
        model.compile(optimizer='adam',
                  loss=loss,
                  metrics=['accuracy'])
        model.fit(gen(6,NB_CLASSES),steps_per_epoch=STEPS,epochs=EPOCHS)
    except Exception as e:
        print(e)

#
# RUN TESTS
#
test_train(loss_test,gen)
test_train(loss_targ_list,gen_targ_list)
test_train(loss_3arg,gen_3args)

An example extending tf.keras.losses.Loss (it produces the same errors):

class TargListLoss(tf.keras.losses.Loss):
    
    def __init__(self,loss_func):
        super(TargListLoss,self).__init__()
        self.loss_func=loss_func
        
    def call(self,target,output):
        weight=target[1]
        target=target[0]
        return weight*self.loss_func(target,output)

Solution

  • sample_weight!

    I was trying to build custom loss functions that weight the loss on a per-sample basis, but this is exactly what Keras's sample_weight is for.

    Apologies all for the silly question, though hopefully this keeps others from repeating my mistake. I think I missed this because I was originally planning to determine the weight by passing metadata directly to the loss function. In retrospect, it doesn't make sense to include the metadata-to-weight logic in your loss function, as it is application-dependent.
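
    As a minimal sketch (reusing the Toy model and the generator helpers defined above), Keras accepts sample weights in two standard ways: directly via the sample_weight argument of fit, or as the third element of each tuple a generator yields:

    model=Toy(NB_CLASSES)
    model.compile(optimizer='adam',loss='categorical_crossentropy')

    x=gen_inpt(6)               # one batch of inputs
    y=gen_targ(NB_CLASSES)      # matching one-hot targets
    w=tf.constant([1.0,100.0])  # one weight per sample in the batch

    # option 1: pass the weights directly to fit
    model.fit(x,y,sample_weight=w)

    # option 2: yield (inputs,targets,sample_weights) tuples
    model.fit(((gen_inpt(6),gen_targ(NB_CLASSES),w) for b in range(STEPS)),steps_per_epoch=STEPS)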

    For completeness, the code below shows how passing a third element from the generator does indeed weight each sample:

    import tensorflow as tf
    from tensorflow.keras import layers

    BATCH_SIZE=2
    SIZE=3
    STEPS=8
    EPOCHS=3
    NB_CLASSES=4
    
    
    def gen_inpt(ch_in):
        return tf.random.uniform((BATCH_SIZE,SIZE,SIZE,ch_in))
    
    def gen_targ(nb_classes):
        t=tf.random.uniform((BATCH_SIZE,SIZE,SIZE),maxval=nb_classes,dtype=tf.int32)
        return tf.keras.utils.to_categorical(t,num_classes=nb_classes)
            
    def gen_3args(ch_in,ch_out,dummy_sw):
        if dummy_sw:
            return ( ( gen_inpt(ch_in), gen_targ(ch_out), tf.convert_to_tensor(dummy_sw) ) for b in range(BATCH_SIZE*STEPS*EPOCHS) )
        else:
            return ( ( gen_inpt(ch_in), gen_targ(ch_out) ) for b in range(BATCH_SIZE*STEPS*EPOCHS) )
    
        
    class Toy(tf.keras.Model):
        
        def __init__(self,nb_classes):
            super(Toy, self).__init__()
            self.l1=layers.Conv2D(32,3,padding='same')
            self.l2=layers.Conv2D(nb_classes,3,padding='same')
            
        def call(self,x):
            x=self.l1(x)
            x=self.l2(x)
            return x
        
    loss_func=tf.keras.losses.CategoricalCrossentropy()
    
    def test_train(loss,gen):
        try: 
            model=Toy(NB_CLASSES)    
            model.compile(optimizer='adam',
                      loss=loss,
                      metrics=['accuracy'])
            model.fit(gen,steps_per_epoch=STEPS,epochs=EPOCHS)
        except Exception as e:
            print(e)
    
    #
    # RUN TESTS
    #
    print('None: unweighted')
    test_train(loss_func,gen_3args(6,NB_CLASSES,None))
    print('ones: same as None')
    test_train(loss_func,gen_3args(6,NB_CLASSES,[1,1]))
    print('100s: should be roughly 100 times the loss of None')
    test_train(loss_func,gen_3args(6,NB_CLASSES,[100,100]))
    print('[0,100]: should be roughly 1/2 the 100s loss')
    test_train(loss_func,gen_3args(6,NB_CLASSES,[0,100]))
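
    In practice, the metadata-to-weight logic then lives in the data pipeline rather than in the loss function. Here is a hypothetical sketch for the geographic-location case (REGION_WEIGHTS and the location strings are invented for illustration):

    # hypothetical mapping from per-sample geographic metadata to a weight
    REGION_WEIGHTS={'coastal':2.0,'inland':1.0}

    def region_weight(locations):
        """ map per-sample metadata to per-sample weights """
        return tf.constant([REGION_WEIGHTS[loc] for loc in locations])

    def gen_weighted(ch_in,ch_out):
        for b in range(BATCH_SIZE*STEPS*EPOCHS):
            locations=['coastal','inland']  # made-up metadata, one entry per sample
            yield gen_inpt(ch_in), gen_targ(ch_out), region_weight(locations)

    test_train(loss_func,gen_weighted(6,NB_CLASSES))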