
Using a Python function, how can I trigger the training function in train.py?


I am currently training my YOLOv5 model from the command line:

python train.py --img 512 --batch 14 --epochs 5000 --data neurons.yaml --weights yolov5s.pt --cache ram

However, I want to start training by calling the train() function in train.py directly instead. For now, I run the script as a subprocess:

import subprocess

# Run train.py in a separate interpreter with the same flags as the CLI call above.
subprocess.run(['python3.10',
                'yolov5/train.py',
                '--img',
                '512',
                '--batch',
                '14',
                '--epochs',
                '2',
                '--data',
                'neurons.yaml',
                '--weights',
                'yolov5s.pt',
                '--cache',
                'ram'])

I would like to know how to pass all the command-line parameters I am currently using to the train() function in train.py. train() expects a parameter called opt, which is the namespace object produced by argparse, as shown under Extras below.

Here is the train() function definition in the train.py file:

def train(hyp, opt, device, callbacks):
    # Code implementation

The opt parameter is obtained from the parse_opt() function, which uses argparse for parsing the command-line arguments.

How can I modify my code to trigger the training process using the train() method and pass the necessary command-line parameters?

Extras

The opt object is created in the script's entry point:

if __name__ == '__main__':
    opt = parse_opt()
    main(opt)

Here is the parse_opt() function, which builds opt with an argparse parser:

def parse_opt(known=False):
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='initial weights path')
    parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
    parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
    parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-low.yaml', help='hyperparameters path')
    parser.add_argument('--epochs', type=int, default=100, help='total training epochs')
    parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
    parser.add_argument('--rect', action='store_true', help='rectangular training')
    parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    parser.add_argument('--noval', action='store_true', help='only validate final epoch')
    parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
    parser.add_argument('--noplots', action='store_true', help='save no plot files')
    parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
    parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
    parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk')
    parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
    parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
    parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer')
    parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
    parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
    parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--quad', action='store_true', help='quad dataloader')
    parser.add_argument('--cos-lr', action='store_true', help='cosine LR scheduler')
    parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
    parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
    parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
    parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
    parser.add_argument('--seed', type=int, default=0, help='Global training seed')
    parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')

    # Logger arguments
    parser.add_argument('--entity', default=None, help='Entity')
    parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='Upload data, "val" option')
    parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval')
    parser.add_argument('--artifact_alias', type=str, default='latest', help='Version of dataset artifact to use')

    return parser.parse_known_args()[0] if known else parser.parse_args()

Solution

  • ArgumentParser.parse_args returns an argparse.Namespace object, which is just a simple "bag of attributes" object. Per the docs:

    This class is deliberately simple, just an object subclass with a readable string representation

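    As opt, you can pass any object that has an attribute set for each of the CLI arguments. The attribute names are the argparse destinations rather than the flag spellings, so --img is stored as imgsz and --batch-size as batch_size. For example, any of

    class Opt:  # a bare object() instance cannot take new attributes
        pass
    opt = Opt()
    opt.imgsz = 512
    # ...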
    

    or

    import argparse
    opt = argparse.Namespace(
        imgsz=512,
        # ...
    )
    

    or

    import types
    opt = types.SimpleNamespace(
        imgsz=512,
        # ...
    )
    

    would do the trick. Be mindful that many of the CLI arguments have default values that get populated into opt, and the training code may read any of them, so you will likely need to set more attributes than the handful of flags you pass on the command line.
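One way to avoid listing every default by hand, sketched here under assumptions: parse an empty argument list so that argparse fills in the defaults, then overwrite only the fields you care about. build_parser and make_opt are hypothetical helpers, and the parser below is a small stand-in that copies just a few of the arguments from parse_opt() (with plain strings in place of the ROOT-based paths), not the real YOLOv5 parser.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Stand-in for the parser created inside parse_opt(); only a few
    # arguments are reproduced, and ROOT-based defaults are simplified.
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='dataset.yaml path')
    parser.add_argument('--epochs', type=int, default=100, help='total training epochs')
    parser.add_argument('--batch-size', type=int, default=16, help='total batch size')
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='image size')
    parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk')
    return parser


def make_opt(**overrides) -> argparse.Namespace:
    # Parsing an empty list returns a Namespace that holds every default.
    opt = build_parser().parse_args([])
    for name, value in overrides.items():
        # Reject names the parser does not define, e.g. 'img' instead of 'imgsz'.
        if not hasattr(opt, name):
            raise AttributeError(f'unknown option: {name}')
        setattr(opt, name, value)
    return opt


opt = make_opt(imgsz=512, batch_size=14, epochs=2, data='neurons.yaml', cache='ram')
print(opt)
```

Checking each override against the parsed defaults catches misspelled attribute names early, which a hand-built namespace would silently accept.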

    Alternatively, you could create a modified version of parse_opt that parses your existing CLI arguments from a list into the same namespace object that gets created when running train.py:

    def parse_opt_modified(raw_args: list[str]) -> argparse.Namespace:
        parser = argparse.ArgumentParser()
        # .... (same add_argument calls as in parse_opt)
        return parser.parse_args(raw_args)  # parse the given list instead of sys.argv
    

    then in your code, call it as

    opt = parse_opt_modified(
        [
             '--img',
             '512',
             '--batch',
             '14',
             '--epochs',
             '2',
             '--data',
             'neurons.yaml',
             '--weights',
             'yolov5s.pt',
             '--cache',
             'ram',
        ]
    )
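To make the list-based approach concrete, here is a minimal runnable sketch. It assumes a cut-down parser with only the flags used in the question (the real parse_opt() defines many more), and the optional raw_args default is an illustrative choice, not something train.py itself provides. It shows that the --img alias and the abbreviated --batch flag end up as opt.imgsz and opt.batch_size.

```python
import argparse
from typing import Optional


def parse_opt_modified(raw_args: Optional[list] = None) -> argparse.Namespace:
    # Cut-down stand-in: only the flags from the question are defined here.
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='dataset.yaml path')
    parser.add_argument('--epochs', type=int, default=100, help='total training epochs')
    parser.add_argument('--batch-size', type=int, default=16, help='total batch size')
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='image size')
    parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk')
    # With raw_args=None argparse falls back to sys.argv[1:], so the same
    # function still works when the script is started from the command line.
    return parser.parse_args(raw_args)


opt = parse_opt_modified([
    '--img', '512',       # alias of --imgsz, stored as opt.imgsz
    '--batch', '14',      # unambiguous prefix of --batch-size, stored as opt.batch_size
    '--epochs', '2',
    '--data', 'neurons.yaml',
    '--weights', 'yolov5s.pt',
    '--cache', 'ram',
])
print(opt.imgsz, opt.batch_size, opt.epochs, opt.data, opt.weights, opt.cache)
```

A namespace built this way has the same shape as the one main(opt) receives in the `if __name__ == '__main__':` block shown in the question, so handing it to main() after importing train.py should follow the same path as the CLI. That import and call are left out here because they need a YOLOv5 checkout.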