
Fairseq Transformer model not working (Float can't be cast to Long)

I've installed Python 3.8, PyTorch 1.7 and fairseq 0.10.1 on a new machine, then copied over a script and model from a machine with Python 3.6, PyTorch 1.4 and fairseq 0.9.0, where they work fine.

The model is loaded and prepared with:

from fairseq.models.transformer import TransformerModel
model = TransformerModel.from_pretrained(...)

Then used with:

inputs = [model.binarize(encode(src, text)) for text in texts]
batched_hypos = model.generate(inputs, beam)

For example, inputs looks like [tensor([ 116, 1864, 181, 6, 2]), tensor([ 5, 432, 7, 2])]

It fails an assertion; the last part of the call stack is:

    batched_hypos = model.generate(inputs, beam)
  File "/path/to/fairseq/", line 125, in generate
    sample = self._build_sample(tokens)
  File "/path/to/fairseq/", line 196, in _build_sample
    assert torch.is_tensor(src_tokens)
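The `_build_sample` assertion hints at an API mismatch: as far as I can tell, fairseq 0.9.0's hub `generate()` takes one 1-D token tensor at a time, while 0.10.x also accepts a list of them. A minimal workaround sketch, assuming only a hub-style `generate(tokens, beam=...)` method, is to call it once per sentence. `generate_each` is a hypothetical helper, not part of fairseq.

```python
from typing import Any, Callable, List, Sequence


def generate_each(generate: Callable[..., Any], inputs: Sequence[Any], beam: int = 5) -> List[Any]:
    """Call a hub-style generate() once per tokenized sentence.

    Hypothetical helper: passing a whole list trips the
    assert torch.is_tensor(src_tokens) check in fairseq 0.9.0,
    so each 1-D tensor is sent on its own instead.
    """
    # One result (a list of hypotheses) per input sentence, in order.
    return [generate(tokens, beam=beam) for tokens in inputs]
```

With the script above this would be `batched_hypos = generate_each(model.generate, inputs, beam)`. Per-sentence calls give up batching speed, but they do not depend on which fairseq version is installed, assuming 0.10.x still accepts a single 1-D tensor (which I believe it does).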

If instead I use fairseq-interactive from the command line, it fails with RuntimeError: result type Float can't be cast to the desired output type Long. (Full stack trace below.)

As using the CLI also fails, my hunch is that a model built with fairseq 0.9.x cannot be used with fairseq 0.10.x. If so, is there a way to update the model (i.e. without having to retrain it)? And if not, what could the problem be, and how do I fix it?

BTW, I get exactly the same error if I add --cpu to the command-line args, so the GPU or CUDA version can be ruled out as a possible cause.

$ fairseq-interactive path/to/dicts --path models/ --source-lang ja --target-lang en  --remove-bpe sentencepiece

  File "/path/to/bin/fairseq-interactive", line 11, in <module>
  File "/path/to/lib/python3.8/site-packages/fairseq_cli/", line 190, in cli_main
  File "/path/to/lib/python3.8/site-packages/fairseq_cli/", line 149, in main
    translations = task.inference_step(generator, models, sample)
  File "/path/to/lib/python3.8/site-packages/fairseq/tasks/", line 265, in inference_step
    return generator.generate(models, sample, prefix_tokens=prefix_tokens)
  File "/path/to/lib/python3.8/site-packages/torch/autograd/", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/path/to/lib/python3.8/site-packages/fairseq/", line 113, in generate
    return self._generate(model, sample, **kwargs)
  File "/path/to/lib/python3.8/site-packages/torch/autograd/", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/path/to/lib/python3.8/site-packages/fairseq/", line 376, in _generate
    cand_scores, cand_indices, cand_beams =
  File "/path/to/lib/python3.8/site-packages/fairseq/", line 81, in step
    torch.div(self.indices_buf, vocab_size, out=self.beams_buf)
RuntimeError: result type Float can't be cast to the desired output type Long
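The failing `torch.div` line is beam search mapping each candidate back to the beam it came from. Scores for all beams are flattened into one row of `beam_size * vocab_size` entries before the top candidates are picked, so flat index `i` belongs to beam `i // vocab_size` and names token `i % vocab_size`. That needs floor division. In PyTorch 1.7, `torch.div` on integer tensors performs true division and produces floats, which cannot be written into the Long `out=` buffer, hence the error. A plain-Python sketch of the arithmetic (no torch; `project_indices` is a hypothetical name):

```python
from typing import List, Tuple


def project_indices(flat_indices: List[int], vocab_size: int) -> List[Tuple[int, int]]:
    """Map flat top-k indices over (beam, vocab) back to (beam, token) pairs.

    Hypothetical helper mirroring the arithmetic of the failing torch.div line.
    """
    pairs = []
    for i in flat_indices:
        # divmod gives floor division and remainder together:
        # beam = i // vocab_size, token = i % vocab_size
        beam, token = divmod(i, vocab_size)
        pairs.append((beam, token))
    return pairs


if __name__ == "__main__":
    # Three candidates from a 3-beam search over a 10-token vocabulary.
    print(project_indices([3, 17, 25], vocab_size=10))  # [(0, 3), (1, 7), (2, 5)]
    # True division, as PyTorch 1.7's torch.div does, gives a non-integer beam:
    print(25 / 10)  # 2.5
```

For non-negative indices like these, the tensor equivalent is `indices // vocab_size` or `torch.floor_divide(indices, vocab_size)`, both of which keep an integer dtype.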


  • (UPDATE: the instructions below install PyTorch without GPU support. Going back to the pytorch channel gets GPU support, but then fairseq breaks again. I've not yet cracked the secret code to get everything working together.)

    Solved this by wiping conda and starting again. I've decided to self-answer rather than delete the question, as those error messages turned out to be useless (to put it politely), so maybe this will help someone else who googles them.

    First: I actually had fairseq 0.9.0 installed, even though 0.10.1 was listed first on conda-forge. So my hunch was wrong, and something more obscure was at work. I then couldn't get uninstall or upgrade to work, hence my decision to wipe Anaconda completely and start again.

    Second, I noticed something deep in the conda documentation saying to install everything in one go, to avoid conflicts. Not my definition of how a package manager should work, but anyway.

    Third, I created a "test" conda environment, rather than using the "base" default. I suspect this had nothing to do with getting it to work, but I mention it just in case.

    So, my successful install command was:

    conda install -c conda-forge pytorch cudatoolkit=11.0 nvidia-apex fairseq==0.10.1 sentencepiece

    This gives me Python 3.7.9 (not the 3.8.5 the OS has installed), PyTorch 1.7.1, fairseq 0.10.1 and sentencepiece 0.1.92.
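Given that the root cause was an older fairseq than conda appeared to list, a quick check of what Python actually imports can save a round of debugging. A minimal sketch with hypothetical helper names: it reads each module's `__version__`, flags fairseq older than 0.10 paired with PyTorch 1.5 or newer (that cutoff is my assumption, based on 0.9.0 working with PyTorch 1.4 and failing with 1.7 in this thread), and reports whether the installed PyTorch build can see CUDA, which is relevant to the GPU issue in the update.

```python
import importlib
import re
from typing import Optional, Tuple


def installed_version(module_name: str) -> Optional[str]:
    """Return the __version__ of the module Python actually imports, or None."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, "__version__", None)


def parse_version(version: str) -> Tuple[int, ...]:
    """Keep the leading numeric components: "1.7.1+cu110" -> (1, 7, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    return tuple(int(part) for part in match.group(0).split(".")) if match else ()


def known_bad_combo(fairseq_version: str, torch_version: str) -> bool:
    """Assumed rule: fairseq older than 0.10 with PyTorch 1.5 or newer."""
    return parse_version(fairseq_version) < (0, 10) and parse_version(torch_version) >= (1, 5)


if __name__ == "__main__":
    fairseq_v = installed_version("fairseq")
    torch_v = installed_version("torch")
    print(f"fairseq {fairseq_v}, torch {torch_v}")
    if fairseq_v and torch_v and known_bad_combo(fairseq_v, torch_v):
        print("Old fairseq with newer PyTorch: expect the Float/Long torch.div error")
    if torch_v:
        import torch
        # False here means a CPU-only PyTorch build (or no usable GPU).
        print("CUDA available:", torch.cuda.is_available(), "| built for CUDA:", torch.version.cuda)
```

Reading `__version__` from the imported module, rather than trusting the package list, reports the copy that actually runs.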