
How can I get alignment or attention information for translations produced by a Torch Hub model?


PyTorch Hub provides pretrained models, such as: https://pytorch.org/hub/pytorch_fairseq_translation/

These models can be used from Python or interactively through a CLI. With the CLI, alignments can be printed using the --print-alignment flag. The following command works in a terminal after installing fairseq (and PyTorch):

curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf -
MODEL_DIR=wmt14.en-fr.fconv-py
fairseq-interactive \
    --path $MODEL_DIR/model.pt $MODEL_DIR \
    --beam 5 --source-lang en --target-lang fr \
    --tokenizer moses \
    --bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes \
    --print-alignment
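If running the CLI is acceptable, its alignment output can still be consumed from Python, for example by capturing the command's stdout with subprocess.run(..., capture_output=True, text=True). Below is a minimal parser sketch; parse_alignment_lines is a hypothetical helper, not part of fairseq. The line formats it accepts are assumptions and may differ between fairseq versions: either an "A-<id>" line with space-separated "src-tgt" index pairs, or an "A" line with one source index per target token. Compare against your real output before relying on it.

```python
import re
from typing import List, Tuple

# ASSUMED line shapes (verify against your fairseq version's actual output):
#   "A-0\t0-0 1-2 2-1"  -> space-separated "src-tgt" index pairs
#   "A\t0 2 1"          -> one source index per target token, in target order
_ALIGN_LINE = re.compile(r"^A(?:-\d+)?\t(.*)$")


def parse_alignment_lines(output: str) -> List[List[Tuple[int, int]]]:
    """Return one list of (src_idx, tgt_idx) pairs per alignment line in `output`."""
    alignments = []
    for line in output.splitlines():
        match = _ALIGN_LINE.match(line)
        if match is None:
            continue  # S-, H-, P- and any other lines are ignored
        pairs = []
        for tgt_pos, item in enumerate(match.group(1).split()):
            if "-" in item:
                # "src-tgt" pair
                src, tgt = item.split("-")
                pairs.append((int(src), int(tgt)))
            else:
                # Bare integer: the source index aligned to this target position.
                pairs.append((int(item), tgt_pos))
        alignments.append(pairs)
    return alignments
```

Each returned list corresponds to one alignment line, so with a single input sentence and default settings you would typically look at the first element.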

In Python, it is possible to pass the keyword arguments verbose and print_alignment:

import torch

en2fr = torch.hub.load('pytorch/fairseq', 'transformer.wmt14.en-fr', tokenizer='moses', bpe='subword_nmt')

fr = en2fr.translate('Hello world!', beam=5, verbose=True, print_alignment=True)

However, this only prints the alignment as a log message; it is not part of the return value. Moreover, in fairseq 0.9 it seems to be broken and raises an error (issue).

Is there a way to access the alignment information (or possibly even the full attention matrix) from Python code?
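For context on how the two relate: a common way to turn a soft attention matrix into a hard word alignment is to pick, for each target position, the source position with the highest attention weight. A minimal sketch follows; hard_alignment_from_attention is a hypothetical helper, and the src_len x tgt_len layout is an assumption (check the shape of whatever matrix your fairseq version gives you, and transpose if needed). A torch tensor can be passed in via .tolist().

```python
from typing import List, Sequence, Tuple


def hard_alignment_from_attention(
    attention: Sequence[Sequence[float]],
) -> List[Tuple[int, int]]:
    """Turn a soft attention matrix into (src_idx, tgt_idx) pairs.

    ASSUMPTION: `attention` is laid out as src_len x tgt_len, i.e.
    attention[s][t] is the weight target token t puts on source token s.
    """
    if not attention:
        return []
    src_len = len(attention)
    tgt_len = len(attention[0])
    pairs = []
    for t in range(tgt_len):
        # argmax over the source axis for this target column (first one wins on ties)
        best_src = max(range(src_len), key=lambda s: attention[s][t])
        pairs.append((best_src, t))
    return pairs
```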


Solution

  • I've browsed the fairseq codebase and found a hacky way to get the alignment information. Since it requires editing the fairseq source code itself, I don't consider it an acceptable solution, but maybe it helps someone. (I'm still very interested in how to do this properly.)

    Edit the sample() function and change its return statement. The whole function is shown below to make it easier to find in the code, but only the last line needs to change:

    def sample(self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs) -> List[str]:
        if isinstance(sentences, str):
            return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
        batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
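        # Changed line: return (translation, alignment) pairs instead of plain strings.
        # Note that the "-> List[str]" annotation above is then no longer accurate.
        return [(self.decode(hypos[0]['tokens']), hypos[0]['alignment']) for hypos in batched_hypos]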
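Since sample() only chains encode(), generate() and decode() on the hub object (as the function above shows), the same result can probably be obtained without touching fairseq by making those calls from your own code. If I read the source correctly, sample() is a method of the object returned by torch.hub.load (GeneratorHubInterface in fairseq/hub_utils.py). The sketch below assumes those three methods are callable from outside with the same arguments; translate_with_alignment is a hypothetical name. It also reads an 'attention' entry from the hypothesis dict, which is an assumption: both 'alignment' and 'attention' are fetched with .get(), so they come back as None if your fairseq version does not fill them in.

```python
from typing import Any, List, Optional, Tuple, Union

# (translation, alignment, attention) for the best hypothesis of one sentence
Result = Tuple[str, Optional[Any], Optional[Any]]


def translate_with_alignment(
    hub_model: Any,
    sentences: Union[str, List[str]],
    beam: int = 1,
    verbose: bool = False,
    **kwargs: Any,
) -> Union[Result, List[Result]]:
    """Like the patched sample(), but runs from user code without editing fairseq.

    ASSUMPTION: hub_model exposes encode(), generate() and decode() exactly as
    sample() uses them, and each hypothesis is a dict with a 'tokens' key.
    """
    # A single string in gives a single result out, mirroring sample().
    if isinstance(sentences, str):
        return translate_with_alignment(hub_model, [sentences], beam, verbose, **kwargs)[0]
    # The same encode -> generate -> decode chain that sample() runs internally.
    tokenized = [hub_model.encode(sentence) for sentence in sentences]
    batched_hypos = hub_model.generate(tokenized, beam, verbose, **kwargs)
    results = []
    for hypos in batched_hypos:
        best = hypos[0]  # sample() also uses the first hypothesis
        results.append((
            hub_model.decode(best["tokens"]),
            best.get("alignment"),  # None if not populated
            best.get("attention"),  # ASSUMED key; None if absent
        ))
    return results
```

With the en2fr object from the question, usage would look like translation, alignment, attention = translate_with_alignment(en2fr, 'Hello world!', beam=5). Whether alignment is actually populated may still depend on the fairseq version and on generation options such as print_alignment.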