Search code examples
pythonmacospytorchocrfine-tuning

TrOCR fine tuning in Mac M4 chip (MPS)


I’m encountering an issue when switching my device to mps. My training runs smoothly on cpu, but when I set the device to mps, I get the following error:

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

Could you please help me understand why this error occurs and how I might resolve it? Thank you very much for your assistance!

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image, ImageDraw, ImageFont
import random
import string


def generate_text_image(text, width=384, height=96):
    """Render ``text`` in black, centered on a white RGB image.

    Args:
        text: The string to draw.
        width: Width of the generated image in pixels.
        height: Height of the generated image in pixels.

    Returns:
        A new PIL image of size ``(width, height)`` containing the text.
    """
    image = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(image)

    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 32)
    except (OSError, ImportError):
        # The font file may be missing or unreadable (OSError), or FreeType
        # support may be unavailable (ImportError). Fall back to Pillow's
        # built-in font instead of swallowing every exception.
        font = ImageFont.load_default()

    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    text_width = right - left
    text_height = bottom - top
    # The bounding box can start at a non-zero offset from the drawing
    # origin; subtract that offset so the visible glyphs are what is centered.
    x = (width - text_width) // 2 - left
    y = (height - text_height) // 2 - top
    draw.text((x, y), text, fill="black", font=font)

    return image


class OCRDataset(Dataset):
    """Synthetic OCR dataset of random alphanumeric strings rendered as images.

    Every sample is an ``(image, text)`` pair generated once in ``__init__``;
    the processor is applied lazily in ``__getitem__``.

    Args:
        num_samples: Number of random samples to generate.
        processor: Object used to turn an image into ``pixel_values`` and,
            through its ``tokenizer`` attribute, text into label ids. It must
            be supplied before items are accessed.
    """

    def __init__(self, num_samples=1000, processor=None):
        self.processor = processor
        self.samples = []
        chars = string.ascii_letters + string.digits
        for _ in range(num_samples):
            # Random text of 5 to 10 characters drawn from letters and digits.
            text = "".join(random.choices(chars, k=random.randint(5, 10)))
            image = generate_text_image(text, 230, 100)
            self.samples.append((image, text))

    def __len__(self):
        """Return the number of generated samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the processed sample at ``idx``.

        Returns:
            A dict with ``"pixel_values"`` (the processed image without its
            leading batch dimension) and ``"labels"`` (token ids padded or
            truncated to 20 entries, with padding positions set to -100).
        """
        image, text = self.samples[idx]
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        tokenizer = self.processor.tokenizer
        labels = tokenizer(
            text,
            padding="max_length",
            max_length=20,
            truncation=True,  # keep every label the same length for batching
            return_tensors="pt",
        ).input_ids

        # Drop only the leading batch dimension; a bare squeeze() would also
        # remove any other dimension that happens to have size 1.
        pixel_values = pixel_values.squeeze(0)
        labels = labels.squeeze(0)

        # -100 is the default ignore index of the cross-entropy loss, so
        # padding positions no longer contribute to the training loss.
        labels = labels.masked_fill(labels == tokenizer.pad_token_id, -100)

        return {"pixel_values": pixel_values, "labels": labels}


def main():
    """Fine-tune TrOCR on the synthetic OCR dataset and save the result.

    Loads the ``microsoft/trocr-base-handwritten`` processor and model,
    trains for three epochs with AdamW on 1000 generated samples, prints
    running and per-epoch average losses, and writes the model and processor
    to ``models/trocr``.
    """
    processor = TrOCRProcessor.from_pretrained(
        "microsoft/trocr-base-handwritten", use_fast=True
    )
    model = VisionEncoderDecoderModel.from_pretrained(
        "microsoft/trocr-base-handwritten"
    )

    # Token ids the encoder-decoder wrapper needs to build decoder inputs
    # from the labels.
    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    model.config.vocab_size = model.config.decoder.vocab_size

    train_dataset = OCRDataset(num_samples=1000, processor=processor)
    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

    # Prefer Apple's MPS backend, but fall back to the CPU when it is not
    # available instead of failing on machines without MPS support.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Training on device: {device}")
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    num_epochs = 3
    num_batches = len(train_dataloader)
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        print(f"\nStarting Epoch {epoch+1}/{num_epochs}")
        for batch_idx, batch in enumerate(train_dataloader):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # Report the running mean loss every 10 batches.
            if (batch_idx + 1) % 10 == 0:
                current_loss = total_loss / (batch_idx + 1)
                print(
                    f"Batch {batch_idx+1}/{num_batches} | "
                    f"Current Loss: {current_loss:.4f}"
                )

        avg_loss = total_loss / num_batches
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"Average Loss: {avg_loss:.4f}")
        print("-" * 50)

    # Single source of truth for the output location used below.
    output_dir = "models/trocr"
    print("\nTraining completed!")
    print(f"Saving model to: {output_dir}")

    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)
    print("Model saved successfully!")

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

I have confirmed that MPS is available on the Mac with the M4 chip.


Solution

  • I faced the same issue with other transformers models.

    I can't find the GitHub issue associated with your problem, but it has not been fixed yet. All I know is that if you want to use some pretrained transformers models, you can't use 'mps' as the device.

    One option is to use the 'cpu' device instead, which means a much longer training time.

    device = "cpu"  
    

    instead of

    device = "mps"  # or `cpu``
    

    Another possibility is to use similar pretrained models that support the mps device. From my personal experience, the timm library, which contains SOTA computer vision models (https://huggingface.co/timm), works quite smoothly with the MPS backend (I have the M3 Pro).

    Best, Arnaud