Tags: python, machine-learning, serialization, hyperparameters, ray

Serialization error using Ray Tune for hyperparameter tuning


I am trying to tune some hyperparameters of my neural network for an image segmentation problem. I set up the tuner as simply as I could, but when I run my code I get the following error:

2025-02-21 15:19:59,571 INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265
2025-02-21 15:20:00,697 INFO tune.py:253 -- Initializing Ray automatically. For cluster usage or custom Ray initialization, call `ray.init(...)` before `tune.run(...)`.
Traceback (most recent call last):
  File "Q:\Msc\Diplomadolgozat\Unet\Pytorch-UNet\train.py", line 232, in <module>
    analysis = tune.run(train_unet_model, config=config)
  File "C:\Users\sajte\AppData\Local\Programs\Python\Python310\lib\site-packages\ray\tune\tune.py", line 758, in run
    experiments[i] = Experiment(
  File "C:\Users\sajte\AppData\Local\Programs\Python\Python310\lib\site-packages\ray\tune\experiment\experiment.py", line 149, in __init__
    self._run_identifier = Experiment.register_if_needed(run)
  File "C:\Users\sajte\AppData\Local\Programs\Python\Python310\lib\site-packages\ray\tune\experiment\experiment.py", line 360, in register_if_needed
    raise type(e)(str(e) + " " + extra_msg) from None
TypeError: ray.cloudpickle.dumps(<class 'ray.tune.trainable.function_trainable.wrap_function.<locals>.ImplicitFunc'>) failed.
To check which non-serializable variables are captured in scope, re-run the ray script with 'RAY_PICKLE_VERBOSE_DEBUG=1'. Other options:
-Try reproducing the issue by calling `pickle.dumps(trainable)`.
-If the error is typing-related, try removing the type annotations and try again.
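Following the error message's hints, Ray also ships a helper that recursively inspects an object and reports which captured variables fail to pickle; a minimal sketch of running it on the trainable before calling tune.run:

from ray.util import inspect_serializability

# Prints a report of any objects captured by the trainable that cannot be serialized
inspect_serializability(train_unet_model, name="train_unet_model")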

I have found nothing useful on how to resolve it. Here is my code:

Main:

if __name__ == '__main__':
    config={
        "BATCH_SIZE": tune.grid_search([1, 4, 8]),
        "LEARNING_RATE": tune.grid_search([1e-4, 1e-5]),
        "BASE_CHANNELS": tune.grid_search([32, 64, 128]),
        "GRADIENT_CLIPPING": tune.grid_search([0.25, 1.0, 5.0]),
        "WEIGHT_DECAY": tune.grid_search([1e-2, 1e-4, 1e-8]),
        "SCHEDULER_PATIENCE": tune.grid_search([0, 1, 4])
        }
    
    analysis = tune.run(train_unet_model, config=config)
    
    print("Best config: ", analysis.get_best_config(metric="dice_loss"))

train function:

def train_unet_model(config):
    dataset = BasicDataset(dir_img, dir_mask, IMG_SCALE, '_mask')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    model = UNet(n_channels=3, n_classes=CLASSES, bilinear=BILINEAR, base_channels=config["BASE_CHANNELS"])
    model = model.to(memory_format=torch.channels_last)
    model.to(device=device)

    # 2. Split into train / validation partitions
    n_val = int(len(dataset) * VAL_PERCENT)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders
    loader_args = dict(batch_size=config["BATCH_SIZE"], num_workers=os.cpu_count(), pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    # optimizer = optim.RMSprop(model.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=config["MOMENTUM"], foreach=True)
    optimizer = optim.AdamW(model.parameters(), lr=config["LEARNING_RATE"], weight_decay=config["WEIGHT_DECAY"], foreach=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=config["SCHEDULER_PATIENCE"]) 
    grad_scaler = torch.amp.GradScaler(device.type, enabled=AMP)  # GradScaler expects the device type string, not a torch.device object
    criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()
    global_step = 0

    # 5. Begin training
    for epoch in range(1, EPOCHS + 1):
        model.train()
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{EPOCHS}', unit='img') as pbar:
            epoch_loss = 0.0
            for batch in train_loader:
                images, true_masks = batch['image'], batch['mask']

                images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=AMP):
                    masks_pred = model(images)
                    loss = criterion(masks_pred, true_masks)
                    loss += dice_loss(
                        F.softmax(masks_pred, dim=1).float(),
                        F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
                        multiclass=True
                    )

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config["GRADIENT_CLIPPING"])
                grad_scaler.step(optimizer)
                grad_scaler.update()

                global_step += 1
                epoch_loss += loss.item()

                pbar.update(images.shape[0])
                pbar.set_postfix(**{'loss (batch)': loss.item(), 'loss (epoch)': epoch_loss})
                
                # Evaluation round
                division_step = (n_train // (1 * config["BATCH_SIZE"]))
                if division_step > 0:
                    if global_step % division_step == 0:
                        val_score = evaluate(model, val_loader, device, AMP)
                        scheduler.step(val_score)
                        tune.report({"dice_loss": val_score})  # tune.track.log no longer exists; report metrics with tune.report

                        logging.info('Validation Dice score: {}'.format(val_score))                                

                        Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
                        state_dict = model.state_dict()
                        state_dict['mask_values'] = dataset.mask_values
                        torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
                        logging.info(f'Checkpoint {epoch} saved!')

My imports and constants:

import logging
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from ray import tune

from evaluate import evaluate
from unet import UNet
from utils.data_loading import BasicDataset
from utils.dice_score import dice_loss

dir_img = Path('./data/wildfire/train/')
dir_mask = Path('./data/wildfire/train_masks/')
dir_checkpoint = Path('./checkpoints/')

EPOCHS = 5
VAL_PERCENT = 0.1
IMG_SCALE = 1.0
AMP = True
CLASSES = 2
BILINEAR = False

I have found more materials on how to tune with Ray, but each seems more unnecessarily complicated than the last. If someone could direct me to some clean and straightforward material, that would be great. But I have no idea what the issue is with the code above.


Solution

  • While preparing this post I cleaned up some logging and argument parsing from the code, and it turned out that the argument parsing was what had been causing the serialization issue. The problematic part of my code was a line that called the argument parser from inside the train function, as follows:

    def train_unet_model(config):
        ...
        args = get_args()
        model = UNet(n_channels=3,
                     n_classes=args.classes,
                     bilinear=args.bilinear,
                     base_channels=config["BASE_CHANNELS"],
                     kernel_size=config["SAMPLING_KERNEL_SIZE"],
                     use_bias=config["USE_BIAS"],
                     base_mid_channels=config["BASE_MID_CHANNELS"])
        ...
    

    with the argument getter function looking like this:

    def get_args():
        parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
        parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
        ...
        return parser.parse_args()
    

    If you want to keep the parser, call it outside of the trainable and pass the extracted values down as plain variables, e.g. through the config dict, as sketched below.
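    For example, a minimal sketch of that pattern, assuming get_args also defines --classes and --bilinear as used above:

    if __name__ == '__main__':
        # Parse CLI arguments once, in the driver process, never inside the trainable
        args = get_args()

        config = {
            "BASE_CHANNELS": tune.grid_search([32, 64, 128]),
            # Plain values extracted from the Namespace serialize without trouble
            "CLASSES": args.classes,
            "BILINEAR": args.bilinear,
        }

        analysis = tune.run(train_unet_model, config=config)

    and read config["CLASSES"] and config["BILINEAR"] inside train_unet_model instead of calling get_args().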