
How to train a yolov7 model on our own dataset?


When I try to train my model by executing the code cell below:

python train.py --img-size 2048 --cfg cfg/training/yolov7.yaml --hyp data/road_sign_data.yaml --batch 8 --epochs 100 --data data/road_sign.yaml --weights yolov7_training.pt --workers 24 --name yolo_road_det

I get the following error message:

Traceback (most recent call last):
  File "C:\Users\531558\Documents\streamline2\yolov7\train.py", line 12, in <module>
    import torch.distributed as dist
  File "C:\Users\531558\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\__init__.py", line 141, in <module>
    raise err
OSError: [WinError 126] The specified module could not be found. Error loading "C:\Users\531558\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\lib\shm.dll" or one of its dependencies.

It looks like it cannot import torch.distributed.

I tried changing the version of Python I am using (from 3.12 to 3.11.9), but it still does not work. I also tried several other ways of training a YOLOv7 model, but none of them worked. Any solution would be very helpful.


Solution

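  • The traceback fails while loading torch itself (shm.dll or one of its dependencies cannot be found), so the problem is the PyTorch installation rather than YOLOv7. You need to install PyTorch on your system and check that it works: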

    https://pytorch.org/

    Scroll down to the install selector and pick the options that match your setup (pip or conda; Windows, macOS, or Linux; CUDA or CPU-only), then copy and run the command it gives you to install torch and the related libraries.

    You can check that the install worked by running:

    python -c "import torch; print(torch.__version__)"

    You should see output similar to:

    2.3.0+cu118

    or whichever version of PyTorch you installed. If you have an NVIDIA card, definitely pick one of the builds with CUDA support.
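
For a slightly fuller check than the one-liner above, here is a minimal sketch that also tries the exact import that failed in train.py and reports whether the build has CUDA support. The function name torch_report and the dictionary keys are made up for illustration; the torch attributes it reads (torch.__version__, torch.version.cuda, torch.cuda.is_available()) are standard PyTorch.

```python
import importlib.util


def torch_report():
    """Describe the local PyTorch install without crashing if it is missing or broken.

    torch_report and its keys are illustrative names, not part of PyTorch.
    """
    report = {
        "importable": False,
        "version": None,
        "cuda_build": None,
        "cuda_available": False,
        "error": None,
    }
    # Is a torch package visible to *this* interpreter at all?
    if importlib.util.find_spec("torch") is None:
        report["error"] = "torch is not installed for this interpreter"
        return report
    try:
        import torch
        import torch.distributed  # the import that failed in train.py
    except (ImportError, OSError) as exc:
        # OSError covers DLL-loading failures such as WinError 126
        report["error"] = f"{type(exc).__name__}: {exc}"
        return report
    report["importable"] = True
    report["version"] = torch.__version__
    report["cuda_build"] = torch.version.cuda  # None on CPU-only builds
    report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in torch_report().items():
        print(f"{key}: {value}")
```

Run it with the same python executable you use for train.py. If importable is False, the error field shows why, and reinstalling PyTorch with the command from the website is the next step.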


    Also worth noting: in my experience most AI projects work with CUDA 11.8, and only some need or use anything from the 12.1 API. From what I've seen, if your card and driver support 12.1 they will also support 11.8. If you have an NVIDIA card, run nvidia-smi to see some details about your driver.
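
If you would rather read the driver details from Python, here is a minimal sketch that calls nvidia-smi and returns the GPU name and driver version. The helper names query_gpus and parse_smi_csv are made up for illustration, and the sketch assumes nvidia-smi is on your PATH; it returns an empty list otherwise. The highest CUDA version the driver supports is shown in the header of the plain nvidia-smi output, which this sketch does not parse.

```python
import shutil
import subprocess


def parse_smi_csv(text):
    """Parse 'name, driver_version' lines into a list of (name, driver) tuples."""
    gpus = []
    for line in text.strip().splitlines():
        # Split on the last comma so a comma inside the GPU name is preserved
        parts = line.rsplit(",", 1)
        if len(parts) == 2:
            gpus.append((parts[0].strip(), parts[1].strip()))
    return gpus


def query_gpus():
    """Return [(gpu_name, driver_version), ...], or [] if nvidia-smi is unavailable."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        return []  # no NVIDIA driver tools found on PATH
    try:
        result = subprocess.run(
            [exe, "--query-gpu=name,driver_version", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            timeout=30,
            check=True,
        )
    except (subprocess.SubprocessError, OSError):
        return []  # nvidia-smi exists but failed or timed out
    return parse_smi_csv(result.stdout)


if __name__ == "__main__":
    gpus = query_gpus()
    if not gpus:
        print("No NVIDIA GPU detected (or nvidia-smi is not on PATH).")
    for name, driver in gpus:
        print(f"{name}: driver {driver}")
```

An empty result suggests picking the CPU-only build on the PyTorch site; otherwise choose a CUDA build that your driver supports.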