I am trying to run the Llama 2 inference script with the following command in VS Code's debug mode:
torchrun --nproc_per_node 1 example_text_completion.py \
--ckpt_dir models/7B-Chat \
--tokenizer_path tokenizer.model \
--max_seq_len 128 --max_batch_size 4
I can run this command successfully from the command line, so my Python environment appears to be set up correctly.
I have tried the following two debug configurations:
{
    "name": "Python: run_llama2_inference",
    "type": "python",
    "request": "launch",
    "module": "torchrun",
    "args": [
        "--nproc_per_node=1",
        "example_chat_completion.py",
        "--ckpt_dir=models/7B-Chat/",
        "--tokenizer_path=tokenizer.model",
        "--max_seq_len=512",
        "--max_batch_size=4"
    ],
    "console": "integratedTerminal",
    "justMyCode": true,
    "env": {
        "PYTHONPATH": "${workspaceFolder}"
    }
},
Error message: "No module named torchrun"
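The first error occurs because torchrun is a console-script entry point that PyTorch installs, not an importable module, so "module": "torchrun" cannot be resolved. The torchrun command is equivalent to python -m torch.distributed.run, so one way to keep torchrun's behaviour is to point "module" at torch.distributed.run instead. This is an untested sketch: torchrun runs the script in a separate worker process, so "subProcess": true is set explicitly on the assumption that the debugger needs to follow that child for breakpoints in the script to be hit.

```jsonc
{
    "name": "Python: run_llama2_inference (torch.distributed.run)",
    "type": "python",
    "request": "launch",
    // torchrun is a console script; its module form is torch.distributed.run
    "module": "torch.distributed.run",
    "args": [
        // launcher options go before the script path
        "--nproc_per_node=1",
        "example_chat_completion.py",
        // everything after the script path is passed to the script itself
        "--ckpt_dir=models/7B-Chat/",
        "--tokenizer_path=tokenizer.model",
        "--max_seq_len=512",
        "--max_batch_size=4"
    ],
    "console": "integratedTerminal",
    "justMyCode": true,
    // ask the debugger to also attach to child processes (torchrun's worker)
    "subProcess": true,
    "env": {
        "PYTHONPATH": "${workspaceFolder}"
    }
}
```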
{
    "name": "Python: run_llama2_inference",
    "type": "python",
    "request": "launch",
    "module": "torch.distributed.launch",
    "args": [
        "--use-env",
        "example_chat_completion.py",
        "--nproc_per_node=1",
        "--ckpt_dir=models/7B-Chat/",
        "--tokenizer_path=tokenizer.model",
        "--max_seq_len=512",
        "--max_batch_size=4"
    ],
    "console": "integratedTerminal",
    "justMyCode": true,
    "env": {
        "PYTHONPATH": "${workspaceFolder}"
    }
},
Error message: "Could not consume arg: --nproc_per_node=1"
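In the second configuration, --nproc_per_node=1 is listed after example_chat_completion.py. The launcher treats everything after the script path as arguments for the script, so the option is forwarded to the example script, whose own argument parser (python-fire in the Llama examples, as far as I can tell) rejects it. The sketch below only moves the launcher options in front of the script path; it is untested, and torch.distributed.launch is deprecated in favour of torchrun, so the torch.distributed.run variant is probably the better long-term choice.

```jsonc
{
    "name": "Python: run_llama2_inference (torch.distributed.launch)",
    "type": "python",
    "request": "launch",
    "module": "torch.distributed.launch",
    "args": [
        // launcher options first...
        "--use-env",
        "--nproc_per_node=1",
        // ...then the script path, then the script's own arguments
        "example_chat_completion.py",
        "--ckpt_dir=models/7B-Chat/",
        "--tokenizer_path=tokenizer.model",
        "--max_seq_len=512",
        "--max_batch_size=4"
    ],
    "console": "integratedTerminal",
    "justMyCode": true,
    "env": {
        "PYTHONPATH": "${workspaceFolder}"
    }
}
```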
Neither configuration works. Any ideas or advice would be much appreciated!
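You can use the "program" field to specify the Python script you want to run (example_text_completion.py or example_chat_completion.py) and pass the remaining arguments through the "args" field. Because the script is then launched directly rather than through torchrun, the torchrun-specific --nproc_per_node option is dropped.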
Here's an example of how you can modify your launch configuration:
{
    "name": "Python: run_llama2_inference",
    "type": "python",
    "request": "launch",
    "program": "${workspaceFolder}/example_chat_completion.py",
    "args": [
        "--ckpt_dir=models/7B-Chat/",
        "--tokenizer_path=tokenizer.model",
        "--max_seq_len=512",
        "--max_batch_size=4"
    ],
    "console": "integratedTerminal",
    "justMyCode": true,
    "env": {
        "PYTHONPATH": "${workspaceFolder}"
    }
}
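One caveat when launching the script directly: without torchrun, nothing sets the environment variables that torch.distributed's default env:// initialisation reads (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT). If the script calls torch.distributed.init_process_group without an explicit init method, as the Llama 2 example code appears to, start-up may fail with an error about a missing variable. The following sketch sets single-process values by hand. It assumes port 29500 (torchrun's default) is free, and it includes LOCAL_RANK on the assumption that the script reads it to choose a GPU.

```jsonc
{
    "name": "Python: run_llama2_inference (no torchrun)",
    "type": "python",
    "request": "launch",
    "program": "${workspaceFolder}/example_chat_completion.py",
    "args": [
        "--ckpt_dir=models/7B-Chat/",
        "--tokenizer_path=tokenizer.model",
        "--max_seq_len=512",
        "--max_batch_size=4"
    ],
    "console": "integratedTerminal",
    "justMyCode": true,
    "env": {
        "PYTHONPATH": "${workspaceFolder}",
        // single-process values that torchrun would otherwise provide
        "RANK": "0",
        "LOCAL_RANK": "0",
        "WORLD_SIZE": "1",
        "MASTER_ADDR": "localhost",
        // assumed free; 29500 is torchrun's default port
        "MASTER_PORT": "29500"
    }
}
```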