I’m encountering an issue when switching my device to mps. My training runs smoothly on cpu, but when I set the device to mps, I get the following error:
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
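From what I understand, this error normally comes from calling .view() on a non-contiguous tensor. A minimal standalone reproduction (unrelated to my training code, just to illustrate the message):

import torch

x = torch.randn(2, 3, 4).transpose(1, 2)  # transposing makes the tensor non-contiguous
print(x.is_contiguous())  # False
y = x.reshape(2, 12)      # works: reshape copies when a view is impossible
z = x.view(2, 12)         # raises the RuntimeError above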
What I don't understand is why the same model code only triggers this on mps. Could you please help me understand why this happens and how I might resolve it? Thank you very much for your assistance!
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image, ImageDraw, ImageFont
import random
import string
def generate_text_image(text, width=384, height=96):
    """Render a text string centered on a blank white image."""
    image = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 32)
    except OSError:  # font file not found, e.g. on macOS
        font = ImageFont.load_default()
    # Measure the rendered text so it can be centered.
    bbox = draw.textbbox((0, 0), text, font=font)
    text_width = bbox[2] - bbox[0]
    text_height = bbox[3] - bbox[1]
    x = (width - text_width) // 2
    y = (height - text_height) // 2
    draw.text((x, y), text, fill="black", font=font)
    return image
class OCRDataset(Dataset):
    """Synthetic OCR dataset: random alphanumeric strings rendered as images."""

    def __init__(self, num_samples=1000, processor=None):
        self.processor = processor
        self.samples = []
        chars = string.ascii_letters + string.digits
        for _ in range(num_samples):
            text = "".join(random.choices(chars, k=random.randint(5, 10)))
            image = generate_text_image(text, 230, 100)
            self.samples.append((image, text))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image, text = self.samples[idx]
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        labels = self.processor.tokenizer(
            text, padding="max_length", max_length=20, return_tensors="pt"
        ).input_ids
        # Drop the batch dimension added by the processor/tokenizer.
        return {"pixel_values": pixel_values.squeeze(), "labels": labels.squeeze()}
def main():
    processor = TrOCRProcessor.from_pretrained(
        "microsoft/trocr-base-handwritten", use_fast=True
    )
    model = VisionEncoderDecoderModel.from_pretrained(
        "microsoft/trocr-base-handwritten"
    )
    # Required config settings for training a VisionEncoderDecoderModel.
    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    model.config.vocab_size = model.config.decoder.vocab_size

    train_dataset = OCRDataset(num_samples=1000, processor=processor)
    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

    device = "mps"  # or "cpu"
    print(f"Training on device: {device}")
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    num_epochs = 3
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        print(f"\nStarting Epoch {epoch+1}/{num_epochs}")
        for batch_idx, batch in enumerate(train_dataloader):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            if (batch_idx + 1) % 10 == 0:
                current_loss = total_loss / (batch_idx + 1)
                print(
                    f"Batch {batch_idx+1}/{len(train_dataloader)} | "
                    f"Current Loss: {current_loss:.4f}"
                )
        avg_loss = total_loss / len(train_dataloader)
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"Average Loss: {avg_loss:.4f}")
        print("-" * 50)

    print("\nTraining completed!")
    print("Saving model to: models/trocr")
    model.save_pretrained("models/trocr")
    processor.save_pretrained("models/trocr")
    print("Model saved successfully!")


if __name__ == "__main__":
    main()
I have confirmed that MPS is available on my Mac with the M4 chip:
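import torch
print(torch.backends.mps.is_available())  # True on my machine
print(torch.backends.mps.is_built())      # True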
I have also faced the same issue with other transformers models.
I can't find the GitHub issue associated with your problem, but it has not been fixed yet. All I know is that some transformers pretrained models simply cannot be used with 'mps' as the device.
One option you have is to fall back to 'cpu', which means much longer training times:
device = "cpu"
instead of
device = "mps" # or `cpu``
Another possibility is to use similar pretrained models that do support mps. From my personal experience, the timm library, which contains SOTA computer vision models (https://huggingface.co/timm), works quite smoothly on mps (I have the M3 Pro). For example:
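A minimal sketch of loading a timm model and running it on mps ("resnet50" is only an illustration; any timm model is created the same way):

import timm
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"

# "resnet50" is just an example; timm.list_models() shows all available names.
model = timm.create_model("resnet50", pretrained=True).to(device)
model.eval()

# Smoke test with a dummy batch on the selected device.
x = torch.randn(1, 3, 224, 224, device=device)
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 1000])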
Best, Arnaud