
Inference error after training an IP-Adapter plus model


I downloaded the code from https://github.com/tencent-ailab/IP-Adapter and ran the following command to train an IP-Adapter plus model (input: text + image, output: image):

accelerate launch --num_processes 2 --multi_gpu --mixed_precision "fp16" \
  tutorial_train_plus.py \
  --pretrained_model_name_or_path="stable-diffusion-v1-5/" \
  --image_encoder_path="models/image_encoder/" \
  --data_json_file="assets/prompt_image.json" \
  --data_root_path="assets/train/" \
  --mixed_precision="fp16" \
  --resolution=512 \
  --train_batch_size=2 \
  --dataloader_num_workers=4 \
  --learning_rate=1e-04 \
  --weight_decay=0.01 \
  --output_dir="out_model/" \
  --save_steps=3

During training, the following message appears, but training continues:

Removed shared tensor {'adapter_modules.27.to_k_ip.weight', 'adapter_modules.1.to_v_ip.weight', 'adapter_modules.31.to_k_ip.weight', 'adapter_modules.15.to_k_ip.weight', 'adapter_modules.31.to_v_ip.weight', 'adapter_modules.11.to_k_ip.weight', 'adapter_modules.23.to_k_ip.weight', 'adapter_modules.3.to_k_ip.weight', 'adapter_modules.25.to_v_ip.weight', 'adapter_modules.21.to_k_ip.weight', 'adapter_modules.17.to_v_ip.weight', 'adapter_modules.13.to_k_ip.weight', 'adapter_modules.17.to_k_ip.weight', 'adapter_modules.19.to_v_ip.weight', 'adapter_modules.13.to_v_ip.weight', 'adapter_modules.7.to_v_ip.weight', 'adapter_modules.7.to_k_ip.weight', 'adapter_modules.29.to_k_ip.weight', 'adapter_modules.3.to_v_ip.weight', 'adapter_modules.5.to_v_ip.weight', 'adapter_modules.21.to_v_ip.weight', 'adapter_modules.5.to_k_ip.weight', 'adapter_modules.23.to_v_ip.weight', 'adapter_modules.25.to_k_ip.weight', 'adapter_modules.1.to_k_ip.weight', 'adapter_modules.9.to_v_ip.weight', 'adapter_modules.9.to_k_ip.weight', 'adapter_modules.15.to_v_ip.weight', 'adapter_modules.27.to_v_ip.weight', 'adapter_modules.29.to_v_ip.weight', 'adapter_modules.19.to_k_ip.weight', 'adapter_modules.11.to_v_ip.weight'} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading
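The warning appears because the same parameters are reachable under two names. As far as I can tell, tutorial_train_plus.py builds `adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())`, so each `to_k_ip`/`to_v_ip` weight shows up in the state dict both under a `unet.<block path>.processor.*` key and under an `adapter_modules.N.*` key. When saving to safetensors, accelerate keeps only one name per shared tensor, and in this run the `adapter_modules.*` names were the ones removed. The toy sketch below reproduces that aliasing with hypothetical modules (`ToyProcessor`, `ToyUNet`, `ToyIPAdapter`) and groups keys by `data_ptr()`, which is a simplification of the storage check safetensors actually performs.

```python
import torch
from torch import nn


class ToyProcessor(nn.Module):
    """Hypothetical stand-in for an IP-Adapter attention processor."""

    def __init__(self, dim: int = 4):
        super().__init__()
        self.to_k_ip = nn.Linear(dim, dim, bias=False)
        self.to_v_ip = nn.Linear(dim, dim, bias=False)


class ToyUNet(nn.Module):
    """Hypothetical UNet that owns the processors."""

    def __init__(self):
        super().__init__()
        self.processors = nn.ModuleList([ToyProcessor(), ToyProcessor()])


class ToyIPAdapter(nn.Module):
    def __init__(self, unet: nn.Module):
        super().__init__()
        self.unet = unet
        # The same processor objects are registered a second time under a new name.
        self.adapter_modules = nn.ModuleList(unet.processors)


def shared_groups(state_dict: dict) -> list[list[str]]:
    """Group state_dict keys whose tensors point at the same memory."""
    groups: dict[int, list[str]] = {}
    for name, tensor in state_dict.items():
        groups.setdefault(tensor.data_ptr(), []).append(name)
    return [sorted(names) for names in groups.values() if len(names) > 1]


model = ToyIPAdapter(ToyUNet())
for group in shared_groups(model.state_dict()):
    print(group)
```

Each printed group pairs an `adapter_modules.*` key with a `unet.*` key for the same tensor. If the saver drops the `adapter_modules.*` member of every pair, a conversion step that reads only `adapter_modules.*` keys for the adapter weights ends up with nothing for those layers.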

After training finished, I converted the weights into ip_adapter.bin following the README and then ran the inference script ip_adapter-plus_demo.py with these model paths:

base_model_path = "SG161222/Realistic_Vision_V4.0_noVAE"
vae_model_path = "stabilityai/sd-vae-ft-mse"
image_encoder_path = "models/image_encoder"
ip_ckpt = "out_model/demo_plus_checkpoint/ip_adapter.bin"

It fails with the following error:

raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ModuleList:
        Missing key(s) in state_dict: "1.to_k_ip.weight", "1.to_v_ip.weight", "3.to_k_ip.weight", "3.to_v_ip.weight", "5.to_k_ip.weight", "5.to_v_ip.weight", "7.to_k_ip.weight", "7.to_v_ip.weight", "9.to_k_ip.weight", "9.to_v_ip.weight", "11.to_k_ip.weight", "11.to_v_ip.weight", "13.to_k_ip.weight", "13.to_v_ip.weight", "15.to_k_ip.weight", "15.to_v_ip.weight", "17.to_k_ip.weight", "17.to_v_ip.weight", "19.to_k_ip.weight", "19.to_v_ip.weight", "21.to_k_ip.weight", "21.to_v_ip.weight", "23.to_k_ip.weight", "23.to_v_ip.weight", "25.to_k_ip.weight", "25.to_v_ip.weight", "27.to_k_ip.weight", "27.to_v_ip.weight", "29.to_k_ip.weight", "29.to_v_ip.weight", "31.to_k_ip.weight", "31.to_v_ip.weight".

Which step is wrong and causes this error?
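One way to confirm that the converted file, rather than the inference script, is at fault is to check ip_adapter.bin for the keys the loader expects. This sketch assumes the layout produced by the README conversion (a dict with an `"ip_adapter"` entry whose keys look like `1.to_k_ip.weight`) and a Stable Diffusion 1.5 UNet with 32 attention processors, of which only the odd-indexed (cross-attention) ones carry `to_k_ip`/`to_v_ip`, as the error message suggests. `missing_ip_keys` is a hypothetical helper name.

```python
def missing_ip_keys(ip_sd: dict, num_processors: int = 32) -> list[str]:
    """Return the to_k_ip/to_v_ip keys that an IP-Adapter state dict lacks.

    Assumption: only odd-indexed processors (cross-attention layers) hold
    IP-Adapter weights, matching the indices listed in the error message.
    """
    expected = [
        f"{i}.{proj}.weight"
        for i in range(1, num_processors, 2)
        for proj in ("to_k_ip", "to_v_ip")
    ]
    return [key for key in expected if key not in ip_sd]
```

Called as `missing_ip_keys(torch.load(ip_ckpt, map_location="cpu")["ip_adapter"])` on the failing checkpoint, it should return the same 32 keys that appear in the error.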


Solution

  • Training and inference now work after setting safe_serialization to False in the training script tutorial_train_plus.py:

    # Save with torch.save (pytorch_model.bin) so shared adapter tensors keep both keys
    accelerator.save_state(save_path, safe_serialization=False)
    

    With this setting, checkpoints are saved as pytorch_model.bin instead of model.safetensors, and the adapter_modules.* keys are no longer removed.

    Once training is complete, adapt the weight-conversion code from the README so that it loads this file:

    import torch

    ckpt = "pytorch_model.bin"  # set the correct path to the checkpoint
    sd = torch.load(ckpt, map_location="cpu")
    

    Running the conversion then produces the ip_adapter.bin file used for inference.
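For reference, here is a sketch of the complete conversion with the checkpoint path switched to pytorch_model.bin. It assumes, as in the README's conversion snippet, that the checkpoint is a flat state dict whose keys start with `unet.`, `image_proj_model.` or `adapter_modules.`, and that the inference code expects `{"image_proj": ..., "ip_adapter": ...}`. The names `split_ip_adapter_state_dict` and `convert` are hypothetical; `torch` is imported inside `convert` only so the key-splitting logic can be checked without it.

```python
def split_ip_adapter_state_dict(sd: dict) -> dict:
    """Split a training checkpoint into the layout loaded at inference time."""
    image_proj_sd, ip_sd = {}, {}
    for key, value in sd.items():
        if key.startswith("image_proj_model."):
            image_proj_sd[key[len("image_proj_model."):]] = value
        elif key.startswith("adapter_modules."):
            ip_sd[key[len("adapter_modules."):]] = value
        # "unet.*" keys are skipped; the base UNet comes from base_model_path.
    if not ip_sd:
        # e.g. a safetensors checkpoint from which the shared adapter tensors were removed
        raise ValueError("no adapter_modules.* keys found; were shared tensors dropped?")
    return {"image_proj": image_proj_sd, "ip_adapter": ip_sd}


def convert(ckpt_path: str = "pytorch_model.bin", out_path: str = "ip_adapter.bin") -> None:
    import torch  # local import keeps split_ip_adapter_state_dict free of torch

    sd = torch.load(ckpt_path, map_location="cpu")
    torch.save(split_ip_adapter_state_dict(sd), out_path)
```

Raising on an empty adapter dict turns the problem from this question into an error at conversion time instead of a missing-key error at inference time. Pass `convert` the path to the pytorch_model.bin inside the checkpoint directory written by `save_state`.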