R package tabnet: how to change the loss to balanced accuracy?

In the tabnet package, I want the loss to be balanced accuracy for multi-class classification, similar to yardstick::bal_accuracy_vec(). How can I do that?

I know how to compute balanced accuracy, but I don't know how to write a function that fits into the tabnet framework. So any help on how to customize the loss in tabnet is welcome.
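For reference, here is a minimal sketch of one common definition of multi-class balanced accuracy, the mean of per-class recall, written in base R. The helper name balanced_accuracy and the toy scores are illustrative only, and yardstick's multi-class estimator may average differently, so its numbers need not match exactly. The sketch also shows why the metric cannot serve as a loss: it only looks at the argmax, so small changes to the scores leave it unchanged and its gradient is zero almost everywhere.

```r
# Balanced accuracy as the mean of per-class recall (one common definition).
# Hypothetical helper: `scores` is an n x K matrix, `truth` holds class ids 1..K.
balanced_accuracy <- function(scores, truth) {
  pred <- max.col(scores, ties.method = "first")  # hard argmax per row
  classes <- sort(unique(truth))
  recalls <- vapply(classes, function(k) mean(pred[truth == k] == k), numeric(1))
  mean(recalls)
}

# Toy scores for 5 samples and 3 classes (illustrative data only)
scores <- rbind(c(2, 1, 0),
                c(0, 3, 1),
                c(1, 0, 0),   # misclassified: true class 2, argmax 1
                c(0, 2, 1),
                c(0, 1, 4))
truth <- c(1, 2, 2, 2, 3)

balanced_accuracy(scores, truth)
# [1] 0.8888889

# Nudging a score without changing any argmax leaves the metric unchanged,
# so its gradient with respect to the scores is zero almost everywhere.
nudged <- scores
nudged[3, 2] <- nudged[3, 2] + 0.5
balanced_accuracy(nudged, truth)
# [1] 0.8888889
```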



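library(tabnet)
library(recipes)

data("attrition", package = "modeldata")
ids <- sample(nrow(attrition), 256)

# The preprocessing step was cut from the post; step_normalize() is an assumed placeholder
rec <-
  recipe(Attrition + JobSatisfaction ~ ., data = attrition[ids, ]) %>%
  step_normalize(all_numeric_predictors())

# Fit tabnet from the recipe, trying a yardstick metric as the loss (this fails)
attrition_fit <-
  tabnet_fit(rec,
             data = attrition[ids, ],
             epochs = 2,
             valid_split = 0.2,
             loss = yardstick::bal_accuracy_vec)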

This gives the error:

Error in x != y : comparison (2) is possible only for atomic and list types
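The error most likely comes from yardstick::bal_accuracy_vec() receiving torch tensors instead of the factor vectors it expects. Even with the right types, balanced accuracy depends only on hard predicted classes, so it provides no useful gradient. Below is a minimal sketch of a differentiable stand-in, assuming tabnet calls a custom loss as loss(input, target) with raw logits of shape (batch, K) and 1-based integer class ids, as nn_cross_entropy_loss() does. soft_bal_accuracy_loss() is a hypothetical helper, not part of tabnet or torch: it replaces the argmax with softmax probabilities and averages a "soft recall" over the classes present in the batch.

```r
library(torch)

# Hypothetical helper (not part of tabnet or torch): a differentiable
# "soft" balanced-accuracy loss. Assumes `input` holds raw logits of shape
# (batch, K) and `target` holds 1-based integer class ids, the layout that
# nn_cross_entropy_loss() expects.
soft_bal_accuracy_loss <- function(input, target) {
  n_classes <- input$size(2)
  probs <- nnf_softmax(input, dim = 2)  # softmax replaces the hard argmax
  classes <- torch_tensor(seq_len(n_classes), dtype = torch_long(),
                          device = input$device)
  # One-hot (batch, K) matrix built by broadcasting target against 1..K
  onehot <- (target$reshape(c(-1, 1)) == classes$reshape(c(1, -1)))$to(dtype = probs$dtype)
  soft_tp <- (probs * onehot)$sum(dim = 1)  # per class: summed probability on the true class
  support <- onehot$sum(dim = 1)            # per class: number of samples in the batch
  present <- support > 0                    # ignore classes absent from this batch
  soft_recall <- soft_tp[present] / support[present]
  1 - soft_recall$mean()                    # minimising this maximises mean soft recall
}
```

It would then be passed as loss = soft_bal_accuracy_loss in tabnet_fit(), preferably with a single multi-class outcome such as JobSatisfaction, since it is unclear how one custom loss would be applied across several outcomes. The soft recall is estimated per mini-batch, so it is a noisy surrogate; the actual metric can still be reported afterwards with yardstick::bal_accuracy_vec() on the predicted classes.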

  • After asking the same question on the tabnet GitHub repo, I got this answer:

    Custom loss is supported as a function in {tabnet}.

    The good news is that you have an example here:

    The bad news is that you must rely on {torch} losses, and you cannot use a {yardstick} metric as a loss. The reason is that the function must be differentiable to drive the gradient in the right direction...

    There are plenty of loss functions in {torch}, all ending in _loss. If one is missing, you can also build your own as a torch module, like entmax and sparsemax in
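Among the built-in {torch} losses, nn_cross_entropy_loss() takes a weight argument with one weight per class. Weighting classes by inverse frequency is a common surrogate for balanced accuracy (it is still cross-entropy, not the metric itself). A minimal sketch, assuming the JobSatisfaction outcome (4 levels) is modelled on its own; class_weights and weighted_ce are illustrative names.

```r
library(torch)
data("attrition", package = "modeldata")

# Inverse-frequency class weights: class k gets n / (K * n_k), so rare
# classes count more and the weighted cross-entropy leans toward equal
# per-class recall, which is what balanced accuracy rewards.
counts <- as.numeric(table(attrition$JobSatisfaction))
class_weights <- sum(counts) / (length(counts) * counts)

# A built-in torch loss module, configured with the per-class weights
weighted_ce <- nn_cross_entropy_loss(
  weight = torch_tensor(class_weights, dtype = torch_float())
)
```

The module would then be passed as loss = weighted_ce to tabnet_fit(), together with a single-outcome recipe such as recipe(JobSatisfaction ~ ., data = attrition[ids, ]). In practice the weights would be computed from those training rows rather than from the full data set.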

