
Torchtext BucketIterator wrapper from tutorial produces SyntaxError


I am following and implementing code from this short tutorial on Torchtext, which is surprisingly clear given the poor documentation of Torchtext.

Once the Iterator (the batch generator) has been created, the author proposes writing a wrapper around it to make the code more reusable (see step 5 in the tutorial).

The code contains a surprisingly long and strange line, which I don't understand and which raises SyntaxError: invalid syntax. Does anyone have a clue what is going on?

(The problematic line is the one that starts with: if self.y_vars is <g [...])

class BatchWrapper:
  def __init__(self, dl, x_var, y_vars):
        self.dl, self.x_var, self.y_vars = dl, x_var, y_vars # we pass in the list of attributes for x <g class="gr_ gr_3178 gr-alert gr_spell gr_inline_cards gr_disable_anim_appear ContextualSpelling ins-del" id="3178" data-gr-id="3178">and y</g>

  def __iter__(self):
        for batch in self.dl:
              x = getattr(batch, self.x_var) # we assume only one input in this wrapper

              if self.y_vars is <g class="gr_ gr_3177 gr-alert gr_gramm gr_inline_cards gr_disable_anim_appear Grammar replaceWithoutSep" id="3177" data-gr-id="3177">not</g> None: # we will concatenate y into a single tensor
                    y = torch.cat([getattr(batch, feat).unsqueeze(1) for feat in self.y_vars], dim=1).float()
              else:
                    y = torch.zeros((1))

              yield (x, y)

  def __len__(self):
        return len(self.dl)

Solution

  • Yeah, I guess there is a typo from the author. Those <g class="gr_ ..."> fragments look like markup from the Grammarly browser extension that was accidentally pasted into the blog post along with the code, which is why Python reports a SyntaxError. I think the correct piece of code is this:

    if self.y_vars is not None:
        y = torch.cat([getattr(batch, feat).unsqueeze(1) for feat in self.y_vars], dim=1).float()
    else:
        y = torch.zeros((1))
    

    The same stray markup also appears in the comment on line 3 of the code in the blog post (the "and y" at the end of the comment is wrapped in a <g> tag).
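
    To sanity-check that the corrected wrapper behaves as intended, here is a minimal, dependency-free sketch. It keeps the same control flow (getattr on the batch for one x field, a list of y fields, or a placeholder when y_vars is None), but uses plain Python objects and lists instead of torchtext batches and torch.cat, so it runs without PyTorch. The field names comment_text, toxic, and insult are made up for illustration.

    ```python
    from types import SimpleNamespace


    class BatchWrapper:
        """Wraps an iterator of batch objects, pulling x and y fields by name."""

        def __init__(self, dl, x_var, y_vars):
            # dl: any iterable of batch objects; x_var: name of the input field;
            # y_vars: list of label field names, or None for unlabeled data
            self.dl, self.x_var, self.y_vars = dl, x_var, y_vars

        def __iter__(self):
            for batch in self.dl:
                x = getattr(batch, self.x_var)  # single input field, as in the tutorial
                if self.y_vars is not None:
                    # The tutorial concatenates tensors with torch.cat(...).float();
                    # here we just collect the values to keep the sketch torch-free.
                    y = [getattr(batch, feat) for feat in self.y_vars]
                else:
                    y = None  # stand-in for the tutorial's torch.zeros((1))
                yield (x, y)

        def __len__(self):
            return len(self.dl)


    # Fake "batches" standing in for torchtext's Batch objects:
    batches = [SimpleNamespace(comment_text=[1, 2, 3], toxic=0, insult=1)]
    wrapped = BatchWrapper(batches, "comment_text", ["toxic", "insult"])
    pairs = list(wrapped)  # → [([1, 2, 3], [0, 1])]
    ```

    With the real torchtext iterators you would pass the BucketIterator as dl and tensors would come out instead of lists, but the getattr-based extraction logic is identical.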