
HuggingFace BetterTransformer in `with` context - cannot disable after context


I am writing a custom `with` context manager that temporarily converts the model to a BetterTransformer model while calling `trainer.evaluate()`.

I evaluated before, inside, and after the `with` context, and noticed that the evaluation after the context still uses BetterTransformer. This is a problem because a subsequent `trainer.train()` call would also use BetterTransformer, which degrades training because of how BetterTransformer handles padding.

How do I create a custom `with` context manager that uses BetterTransformer only inside the context, and not afterwards?

Please find the MWE gist here.

I created a custom context manager:

from optimum.bettertransformer import BetterTransformer


class BetterTransformerContext:
    """Temporarily replace a model with a BetterTransformer model."""

    def __init__(self, model):
        self.model = model
        self.original_model = None

    def __enter__(self):
        # Save a reference to the current model, then transform it.
        self.original_model = self.model
        self.model = BetterTransformer.transform(self.model)
        return self.model

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Attempt to restore the original model on exit.
        self.model = self.original_model
        # self.model = BetterTransformer.reverse(self.model)  # NOTE: same result
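
For reference, the context is exercised roughly like this (a condensed sketch, since the gist is not reproduced here; `uses_bettertransformer` is a hypothetical stand-in for however the MWE inspects the model, and the `trainer` wiring is assumed):

# Condensed sketch of the MWE flow (hypothetical helper names, not the gist verbatim).
print("=" * 10, "Without Optimum (-> should be slow)", "=" * 10)
print("BT before context: ", uses_bettertransformer(trainer.model))
print(trainer.evaluate()["eval_accuracy"])

print("=" * 10, "With Optimum (-> should be fast)", "=" * 10)
with BetterTransformerContext(trainer.model) as bt_model:
    print("BT in context: ", uses_bettertransformer(bt_model))
    print(trainer.evaluate()["eval_accuracy"])

print("=" * 10, "Without Optimum (-> should be slow)", "=" * 10)
print("BT after context: ", uses_bettertransformer(trainer.model))
print(trainer.evaluate()["eval_accuracy"])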

The output is as follows. Evaluation without BetterTransformer runs at roughly 100 it/s; with BetterTransformer it runs at roughly 115 it/s. As you can see, evaluating after the context still runs at about 115 it/s, so the model is still a BetterTransformer model.

========== Without Optimum (-> should be slow) ==========
BT before context:  False
100%|█████████████████████████████| 204/204 [00:01<00:00, 103.09it/s]
0.3161764705882353
========== With Optimum (-> should be fast) ==========
BT in context:  True
100%|█████████████████████████████| 204/204 [00:01<00:00, 116.68it/s]
0.3161764705882353
========== Without Optimum (-> should be slow) ==========
BT after context:  True
100%|█████████████████████████████| 204/204 [00:01<00:00, 116.53it/s]
0.3161764705882353

Solution

  • I found a solution by using a custom context manager on the trainer object, rather than on the model object. Two details matter: `BetterTransformer.transform` is called with `keep_original_model=True`, so the original model is kept intact instead of being overridden (the default `keep_original_model=False` does not preserve it), and the transformed model is assigned to `trainer.model`, which is the reference the trainer actually uses. The original context manager only rebound its own `self.model` attribute, which never changes the model the trainer holds, as the sketch right below shows.
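
    To see that pitfall in isolation, here is a minimal, BetterTransformer-free sketch (all names here are made up for illustration):

    class FakeModel:
        converted = False

    def transform_in_place(model):
        # Stands in for BetterTransformer.transform(model) with the default
        # keep_original_model=False: the one shared object is converted.
        model.converted = True
        return model

    model = FakeModel()
    original = model                 # a second reference, not a copy
    bt = transform_in_place(model)
    print(original.converted)        # True -- "restoring" this reference
                                     # hands back the converted model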

    The custom context manager is as follows:

    from optimum.bettertransformer import BetterTransformer

    class BetterTransformerTrainerContext:
        """Context manager to wrap trainer.model with BetterTransformer."""

        def __init__(self, trainer):
            self.trainer = trainer

        def __enter__(self):
            # keep_original_model=True returns a transformed copy instead of
            # overriding trainer.model.
            self.trainer.model = BetterTransformer.transform(
                self.trainer.model, keep_original_model=True
            )
            return self.trainer

        def __exit__(self, exc_type, exc_val, exc_tb):
            # Swap the untransformed model back onto the trainer.
            self.trainer.model = BetterTransformer.reverse(self.trainer.model)


    It can be used as follows:

    print("=" * 10, "With Optimum (-> should be fast)", "=" * 10)
    with BetterTransformerTrainerContext(trainer) as _optimum_trainer:
        eval_accuracy = _optimum_trainer.evaluate()["eval_accuracy"]
        print(eval_accuracy)
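
    For completeness, the same idea can be written as a generator-based context manager using `contextlib.contextmanager`; the `try`/`finally` also reverts the model if `evaluate()` raises inside the block. This is a sketch along the same lines, not part of the original gist:

    from contextlib import contextmanager

    from optimum.bettertransformer import BetterTransformer

    @contextmanager
    def better_transformer_trainer(trainer):
        """Generator-based equivalent of BetterTransformerTrainerContext."""
        trainer.model = BetterTransformer.transform(
            trainer.model, keep_original_model=True
        )
        try:
            yield trainer
        finally:
            # Revert even if the body of the with block raises.
            trainer.model = BetterTransformer.reverse(trainer.model)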
    

    I hope this is helpful to someone else.