Tags: nlp, huggingface-transformers, huggingface, fine-tuning, peft

Target modules for applying PEFT / LoRA on different models


I am looking at a few different examples of using PEFT on different models. The LoraConfig object takes a target_modules list. In some examples the target modules are ["query_key_value"], in others they are ["q", "v"], and in others something else again.

I don't quite understand where the values of the target modules come from. Where on the model page should I look to find out which modules LoRA can adapt?

One example (for the model Falcon 7B):

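from peft import LoraConfig

# Assumed example values: the original snippet uses these names without defining them
lora_alpha = 16
lora_dropout = 0.1
lora_r = 64

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ],
)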

Another example (for the model OPT-6.7B):

from peft import LoraConfig

config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
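For a sense of scale: LoRA represents the update to each targeted weight matrix as a low-rank product BA, with B of shape d×r and A of shape r×k, so a targeted d×k Linear layer gains r·(d + k) trainable parameters. The sketch below does that arithmetic for the OPT config above. It assumes OPT-6.7B has 32 decoder layers with hidden size 4096 (so q_proj and v_proj are 4096×4096), and the helper names are hypothetical, not part of PEFT.

```python
def lora_params_per_linear(in_features: int, out_features: int, r: int) -> int:
    """Trainable params LoRA adds to one Linear: A is r x in, B is out x r."""
    return r * in_features + out_features * r


def total_lora_params(layer_shapes, num_layers: int, r: int) -> int:
    """Sum LoRA params over the targeted Linear layers of every decoder block.

    layer_shapes maps a target module name to (in_features, out_features).
    """
    per_block = sum(
        lora_params_per_linear(d_in, d_out, r) for d_in, d_out in layer_shapes.values()
    )
    return per_block * num_layers


# Assumed OPT-6.7B shapes: hidden size 4096, 32 decoder layers.
opt_targets = {"q_proj": (4096, 4096), "v_proj": (4096, 4096)}
n = total_lora_params(opt_targets, num_layers=32, r=16)
print(f"{n:,} trainable LoRA parameters")  # 8,388,608 trainable LoRA parameters
```

Adding names to target_modules (or raising r) grows this count linearly, which is one practical reason examples often target only a few projections.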

Yet another (for the model Flan-T5-xxl):

from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,
)
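These strings are not special keywords. They are names of submodules inside each model, and they differ because each architecture names its layers differently: Falcon fuses Q, K and V into a single query_key_value Linear, OPT uses q_proj/v_proj, and T5 uses q/v. To my understanding, recent PEFT versions treat a list of strings as names that must equal the full module path or be its final dotted component(s), and a single string as a regular expression matched against the whole path. The sketch below re-implements that rule in plain Python for illustration only; `matches_target` is a hypothetical helper, not PEFT's API, and the module paths are examples.

```python
import re
from typing import List, Union


def matches_target(module_name: str, target_modules: Union[str, List[str]]) -> bool:
    """Illustrative (hypothetical) version of how LoRA target names select modules.

    - A single string is a regex that must match the full module path.
    - A list matches if an entry equals the path or follows its last "." boundary.
    """
    if isinstance(target_modules, str):
        return re.fullmatch(target_modules, module_name) is not None
    return any(
        module_name == target or module_name.endswith("." + target)
        for target in target_modules
    )


# Example module paths from three architectures (illustrative).
paths = [
    "transformer.h.0.self_attention.query_key_value",  # Falcon
    "model.decoder.layers.0.self_attn.q_proj",          # OPT
    "encoder.block.0.layer.0.SelfAttention.q",          # T5
    "encoder.block.0.layer.0.SelfAttention.k",          # T5
]

print([p for p in paths if matches_target(p, ["q", "v"])])            # T5 "q" only
print([p for p in paths if matches_target(p, ["q_proj", "v_proj"])])  # OPT only
print([p for p in paths if matches_target(p, r".*self_attn\.(q|v)_proj")])  # OPT only
```

Because the match is anchored at a "." boundary, "q" selects T5's q layers but not OPT's q_proj, so a config written for one architecture typically matches nothing in another.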

Solution

  • Let's say that you load some model of your choice:

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("some-model-checkpoint")

    Then you can see the available modules by printing the model:

    print(model)

    You will get something like this (here for Salesforce's CodeGen2.5, which uses the LLaMA architecture):

    LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(51200, 4096, padding_idx=0)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
              (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
              (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
              (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
              (rotary_emb): LlamaRotaryEmbedding()
            )
            (mlp): LlamaMLP(
              (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
              (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
              (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
              (act_fn): SiLUActivation()
            )
            (input_layernorm): LlamaRMSNorm()
            (post_attention_layernorm): LlamaRMSNorm()
          )
        )
        (norm): LlamaRMSNorm()
      )
      (lm_head): Linear(in_features=4096, out_features=51200, bias=False)
    )
    

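    In this case, the LlamaAttention module contains q_proj, k_proj, v_proj, and o_proj, which are among the modules you can target with LoRA. The Linear layers inside LlamaMLP (gate_proj, up_proj, down_proj) are candidates as well. The strings you put in target_modules are simply these submodule names, which is why they differ from one architecture to another.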

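    I suggest reading the LoRA paper (Hu et al., 2021) for more on which modules to adapt. Its experiments adapted only the attention weight matrices and found that, under a fixed parameter budget, adapting W_q and W_v together was a strong choice.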
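Instead of reading the printed tree by eye, you can collect candidate names programmatically. `model.named_modules()` (a standard PyTorch `nn.Module` method) yields `(full_name, module)` pairs, and the last dotted component of each Linear layer's name is what you would put in target_modules. The helper below is a minimal sketch: `linear_leaf_names` is a hypothetical name, and it takes the iterable and the layer types as arguments so it can be exercised without loading a model. Some models (GPT-2, for example) use transformers' `Conv1D` rather than `nn.Linear`, so you may need to pass that type as well.

```python
from typing import Iterable, List, Tuple, Type


def linear_leaf_names(
    named_modules: Iterable[Tuple[str, object]],
    layer_types: Tuple[Type, ...],
) -> List[str]:
    """Return the sorted, de-duplicated last name component of every module
    whose type is one of layer_types (e.g. (torch.nn.Linear,))."""
    names = set()
    for full_name, module in named_modules:
        # The root module has an empty name; skip it.
        if full_name and isinstance(module, layer_types):
            names.add(full_name.rsplit(".", 1)[-1])
    return sorted(names)


# With a real model (requires torch and transformers):
#   import torch
#   print(linear_leaf_names(model.named_modules(), (torch.nn.Linear,)))
# For the LlamaForCausalLM printout above, that would give
# ['down_proj', 'gate_proj', 'k_proj', 'lm_head', 'o_proj', 'q_proj', 'up_proj', 'v_proj'].

# Self-contained demo with hypothetical stand-in classes (no torch needed).
class FakeLinear: ...
class FakeNorm: ...

fake = [
    ("", object()),
    ("model.layers.0.self_attn.q_proj", FakeLinear()),
    ("model.layers.0.self_attn.v_proj", FakeLinear()),
    ("model.layers.1.self_attn.q_proj", FakeLinear()),
    ("model.layers.0.input_layernorm", FakeNorm()),
    ("lm_head", FakeLinear()),
]
print(linear_leaf_names(fake, (FakeLinear,)))  # ['lm_head', 'q_proj', 'v_proj']
```

The list includes the output layer (lm_head), which you may want to leave out; pick a subset such as ["q_proj", "v_proj"], or all attention and MLP projections, for target_modules.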