Tags: python, pytorch, classification, huggingface

Training wav2vec2 for multiple (classification) tasks


I trained a wav2vec2 model using PyTorch and the Hugging Face transformers library. Here is the code: https://github.com/padmalcom/wav2vec2-nonverbalvocalization

I now want to train the model on a second task, e.g. age classification or speech recognition (ASR).

My problem is that I do not really understand how I can configure my model to accept a second input and train another output. Can anybody give me a short explanation?

I know that I have to use multiple heads in my model and that the thing I want to achieve is called "multi-task learning". My problem is that I don't know how to write the model for that.


Solution

  • This would be easier to accomplish if you are OK with giving up some performance for the sake of simplicity.

    This answer is formulated based on the assumption that you're new to multi-task / joint learning strategies and are looking for something simple to start with.

    Approach 1

    Because Wav2Vec2 was designed as a CTC model, you can easily initialize a second classification head with the exact same architecture as here, and sample inputs and labels for both of your tasks at the same time, in the same dataloader. This will get a bit more complicated if you want to do ASR as a second task, but it is easy enough if you choose to focus on Sequence Classification in both cases.
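    For illustration, a rough sketch of a shared encoder with two sequence-classification heads could look like the snippet below. The names Wav2Vec2ForMultiTaskClassification, ClassificationHead, num_labels_1 and num_labels_2 are made up for this answer; the head simply mirrors the plain dense/dropout head used in the linked code.

    import torch
    import torch.nn as nn
    from transformers.models.wav2vec2.modeling_wav2vec2 import (
        Wav2Vec2Model,
        Wav2Vec2PreTrainedModel,
    )

    class ClassificationHead(nn.Module):
        """Pooled-output head: dropout -> dense -> tanh -> dropout -> projection."""
        def __init__(self, hidden_size, num_labels, dropout=0.1):
            super().__init__()
            self.dense = nn.Linear(hidden_size, hidden_size)
            self.dropout = nn.Dropout(dropout)
            self.out_proj = nn.Linear(hidden_size, num_labels)

        def forward(self, features):
            x = torch.tanh(self.dense(self.dropout(features)))
            return self.out_proj(self.dropout(x))

    class Wav2Vec2ForMultiTaskClassification(Wav2Vec2PreTrainedModel):
        """One shared Wav2Vec2 encoder, two independent classification heads."""
        def __init__(self, config, num_labels_1, num_labels_2):
            super().__init__(config)
            self.wav2vec2 = Wav2Vec2Model(config)
            self.classifier_1 = ClassificationHead(config.hidden_size, num_labels_1)
            self.classifier_2 = ClassificationHead(config.hidden_size, num_labels_2)
            self.init_weights()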

    Now, let's say you have your two datasets merged into one, with the following columns (a minimal sketch follows the list):

    1. audio_path: str
    2. label: int
    3. task: int
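
    For instance, such a merged dataset could be as simple as the snippet below. The class name MergedTaskDataset, the file paths, labels and the task-id mapping are hypothetical.

    import torchaudio
    from torch.utils.data import Dataset

    # Example mapping: task 0 = vocalization classification, task 1 = age classification.
    merged_rows = [
        {"audio_path": "data/laugh_001.wav", "label": 3, "task": 0},
        {"audio_path": "data/speaker_042.wav", "label": 1, "task": 1},
    ]

    class MergedTaskDataset(Dataset):
        """Rows from both tasks mixed together; every row carries its task id."""
        def __init__(self, rows, sampling_rate=16000):
            self.rows = rows
            self.sampling_rate = sampling_rate

        def __len__(self):
            return len(self.rows)

        def __getitem__(self, idx):
            row = self.rows[idx]
            speech, sr = torchaudio.load(row["audio_path"])
            speech = torchaudio.functional.resample(speech, sr, self.sampling_rate)
            return {"speech": speech.squeeze(0), "label": row["label"], "task": row["task"]}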

    In the DataCollator, you would normally collate everything into the respective tensors as described here, but with one minor change: you also return a batched task-index tensor along with the inputs and labels.
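    A sketch of such a collator is below. The name MultiTaskDataCollator is made up; padding goes through Wav2Vec2Processor as in the linked answer, with the extra task tensor added on top, and it assumes items shaped like the dataset sketch above.

    from dataclasses import dataclass
    import torch
    from transformers import Wav2Vec2Processor

    @dataclass
    class MultiTaskDataCollator:
        """Pads the raw audio and batches the labels together with the task ids."""
        processor: Wav2Vec2Processor

        def __call__(self, features):
            speech = [f["speech"].numpy() for f in features]
            batch = self.processor(speech, sampling_rate=16000,
                                   padding=True, return_tensors="pt")
            batch["labels"] = torch.tensor([f["label"] for f in features], dtype=torch.long)
            batch["task"] = torch.tensor([f["task"] for f in features], dtype=torch.long)
            return batch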

    Afterwards, you can reuse that information to split up the hidden states and route them to the different classification heads after this line. I.e. if your task tensor looks something like torch.tensor([0, 0, 1, 1]), you can use hidden_states[:2, :, :] as the first classification head's input, and hidden_states[2:, :, :] for the second classification head. The same goes for the labels.

    But more likely the tasks will be interleaved ([0, 1, 0, 0, 1, ...]); in this case you can append each sample's hidden states and label to the respective task's list and then concatenate them.
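    Equivalently, boolean masking over the batch dimension gives you the same split in one step, whether you apply it before or after pooling (a sketch; task, labels and hidden_states as defined above):

    # task and labels have shape (batch,); the mask indexes the batch dimension.
    task_1_mask = task == 0
    task_2_mask = task == 1

    hidden_states_1 = hidden_states[task_1_mask]
    hidden_states_2 = hidden_states[task_2_mask]
    labels_1 = labels[task_1_mask]
    labels_2 = labels[task_2_mask]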

    This is to avoid the complexity of full multi-task learning and to turn the problem into more of a joint-learning approach, for simplicity.

    Your forward pass to classification heads would look like this:

    hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
    hidden_states_1 = ... # As described above
    hidden_states_2 = ... # As described above
    labels_1 = ... # As described above
    labels_2 = ... # As described above
    logits_1 = self.classifier_1(hidden_states_1)
    logits_2 = self.classifier_2(hidden_states_2)
    

    As you get logits for both tasks, you'll need to calculate the loss for each of them separately, and then either sum them, take their mean, or multiply each of them by some weight in advance.

    It would look like this:

      loss_fct = torch.nn.CrossEntropyLoss()  # a separate criterion per task also works
      loss_1 = loss_fct(logits_1.view(-1, self.num_labels_1), labels_1.view(-1))
      loss_2 = loss_fct(logits_2.view(-1, self.num_labels_2), labels_2.view(-1))
      total_loss = (loss_1 * 0.5) + (loss_2 * 0.5)
    

    Please note that there will still be some things to consider, e.g. some batches might not contain data for both tasks if you're not planning to write a custom dataloader.

    This approach won't yield SOTA results that you can put into production (at least not without a lot of further optimization), but will probably be alright for experimentation and private usage.

    Approach 2

    An easier way to go about this is as follows (a short sketch of the key steps comes after the list):

    1. Freeze the Wav2Vec2Model.
    2. Train the classification head for the first task and save the weights.
    3. Train the classification head for the second task and save the weights.
    4. Initialize both classification heads during inference and load the trained weights accordingly.
    5. Do forward passes through either of the heads depending on what you want to do with your inputs.
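
    A minimal sketch of steps 1, 2 and 4, assuming a model like the two-head one sketched above; model, classifier_1/classifier_2, the optimizer settings and the file names are illustrative only.

    import torch

    # Step 1: freeze the pretrained encoder so only the heads receive gradients.
    for param in model.wav2vec2.parameters():
        param.requires_grad = False

    # Steps 2/3: train one head at a time and save its weights.
    optimizer = torch.optim.AdamW(model.classifier_1.parameters(), lr=1e-3)
    # ... training loop over task-1 data ...
    torch.save(model.classifier_1.state_dict(), "classifier_task1.pt")

    # Step 4: at inference time, load both trained heads back into the model.
    model.classifier_1.load_state_dict(torch.load("classifier_task1.pt"))
    model.classifier_2.load_state_dict(torch.load("classifier_task2.pt"))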

    This approach will yield worse results, as the transformer layers of Wav2Vec2 will not be fine-tuned.