
Fastai Regression model with observation weight


Is it possible to use a custom mean squared error loss that takes a sample weight for each observation?

I can run the standard fastai training loop, and I can implement this custom loss in plain PyTorch.

How do I plug it into a fastai Learner for tabular data?

I know Keras already supports this through the sample_weight argument of its .fit method.

def weighted_mse_loss(input, target, weight):
    return torch.sum(weight * (input - target) ** 2)

from fastai.tabular.all import *
import seaborn as sns

df = sns.load_dataset('tips')
df = df.assign(sample_weight = np.random.normal(size = df.shape[0], loc = 10, scale = 2))

y = ['total_bill']
cont = ['tip']
cat = ['sex', 'smoker', 'day', 'time', 'size']

procs = [Normalize, Categorify]

df["Y"] = np.log(df[y] + 1)

MIN = df["Y"].min()
MAX = df["Y"].max()

splits =  RandomSplitter(valid_pct=0.2)(range_of(df))

to = TabularPandas(
    df,
    procs=procs,
    cat_names=cat,
    cont_names=cont,
    y_names="Y",
    splits=splits,
    y_block=RegressionBlock(n_out = 1),
)

dls = to.dataloaders(
    bs=64, shuffle_train=True
)

config = tabular_config(
        embed_p=0.05, 
        y_range=[0, MAX * 1.1],
        bn_final=False,
        ps=[0.05, 0.05, 0.05],
    )

learner = tabular_learner(
        dls,
        layers=[1000, 500, 250],
        config=config,
        wd=0.2,
        metrics=[rmse,],
    )

learner.fit_one_cycle(40, lr_max = 0.01,
                          wd = 0.1)

Solution

  • I am using this workaround:

    1. In y_names for TabularPandas, pass both the weight column and the target, so that each target row is the pair (weight, y):

      to = TabularPandas(df,
                         procs=procs,
                         cat_names=cat,
                         cont_names=cont,
                         y_names=["sample_weight","Y"],
                         splits=splits,
                         y_block=RegressionBlock(n_out = 1))
      
    2. In your loss function, split the target into (weights, target) and apply the weights to the per-sample loss. This example uses binary cross-entropy:

       class SampleWeightedCE(torch.nn.modules.loss._Loss):
           def __init__(self):
               super(SampleWeightedCE, self).__init__()
               # reduction='none' keeps one loss value per row so it can be weighted
               self.ce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')

           def forward(self, output, tgt):
               # tgt holds [sample_weight, Y] per row, matching the y_names order
               weights = tgt[:,0].unsqueeze(1)
               target = tgt[:,1].unsqueeze(1)

               losses = self.ce_loss(output, target) * weights
               # weighted mean of the per-row losses
               return torch.sum(losses) / torch.sum(weights)
      
    3. If you want to measure metrics, you can use the same workaround, such as:

       def accuracy_W(inp, tgt, thresh=0.5, sigmoid=True):
           weights = tgt[:,0].unsqueeze(1)
           target = tgt[:,1].unsqueeze(1)

           if sigmoid: inp = inp.sigmoid()
           classes = (inp >= thresh)
           m_target = (target >= 0.5)
           correct = (m_target == classes)
           return torch.sum(weights * correct) / torch.sum(weights)
      
    4. When calling get_preds() or predict(), split the returned target the same way:

       y_prob, y_out = learn.get_preds(ds_idx=1, with_input=False, with_loss=False, reorder=False)  
       weights = y_out[:,0]
       target = y_out[:,1]
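
For the regression case in the question, the same packed-target idea can be sketched with a weighted MSE instead of cross-entropy. This is only a minimal sketch and has not been run against the full fastai pipeline above. SampleWeightedMSE and weighted_rmse are hypothetical names, and the code assumes the model output has shape (batch, 1) and the target columns are ordered [sample_weight, Y] as in step 1.

```python
import torch
from torch import nn


class SampleWeightedMSE(nn.Module):
    """Weighted MSE for targets packed as [sample_weight, y] (hypothetical helper)."""

    def forward(self, output, tgt):
        weights = tgt[:, 0].unsqueeze(1)   # column 0: per-row weight
        target = tgt[:, 1].unsqueeze(1)    # column 1: regression target
        losses = weights * (output - target) ** 2
        # Normalize by the total weight so the scale does not depend on batch size
        return losses.sum() / weights.sum()


def weighted_rmse(inp, tgt):
    """Weighted RMSE metric using the same packed-target layout (hypothetical helper)."""
    weights = tgt[:, 0].unsqueeze(1)
    target = tgt[:, 1].unsqueeze(1)
    return torch.sqrt((weights * (inp - target) ** 2).sum() / weights.sum())


# Tiny check with made-up numbers: squared errors are 0, 1, 4 and weights are 1, 2, 1
pred = torch.tensor([[1.0], [2.0], [4.0]])
packed = torch.tensor([[1.0, 1.0], [2.0, 3.0], [1.0, 2.0]])
loss = SampleWeightedMSE()(pred, packed)   # (0*1 + 1*2 + 4*1) / 4 = 1.5
```

If this matches your setup, it could presumably be passed to tabular_learner as loss_func=SampleWeightedMSE() and metrics=[weighted_rmse], with n_out=1 so the model predicts only Y. Unlike the weighted_mse_loss in the question, it divides by the sum of the weights rather than returning the raw weighted sum.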