julia, flux.jl, flux-machine-learning

How to use VGG19 in Flux.jl?


I have a specific computer vision problem that I want to try solving with some pretrained models. The Flux.jl docs don't actually include any pretrained models, unlike some other ML frameworks (PyTorch, for example). How would I access those sorts of pretrained models in Flux?


Solution

  • In the Flux ecosystem, functionality such as pretrained computer vision models has been split out into a separate package called Metalhead.jl: https://github.com/FluxML/Metalhead.jl

    Per the docs there, you can create a VGG19 model by doing:

    julia> vgg19 = VGG19()
    VGG19()
    

    You can then pass the model, together with an input image, to a function such as classify to check what the network predicts for that image.
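Putting those steps together, a minimal end-to-end sketch might look like the following. It assumes the Metalhead release whose docs the answer quotes (the one exporting VGG19() and classify), that constructing VGG19() fetches its pretrained ImageNet weights, and that Images.jl is installed for reading files. The image path "elephant.jpg" is a hypothetical placeholder. Newer Metalhead releases have reorganized the model constructors, so check the README of the version you actually install before relying on these names.

```julia
# Sketch only: assumes the Metalhead release described above, which exports
# VGG19() and classify(model, image). Newer releases use different constructors.
using Metalhead   # VGG19 and classify (in that release)
using Images      # re-exports FileIO's `load` for reading image files

# Hypothetical path: replace with an image from your own problem.
img_path = "elephant.jpg"

# Build VGG19; assumed to come with pretrained ImageNet weights.
vgg19 = VGG19()

# Read the image from disk.
img = load(img_path)

# Ask the network for its top ImageNet label for this image.
label = classify(vgg19, img)
println("Predicted class: ", label)
```

For your own problem, a quick sanity check like this on a few known images helps confirm that the pretrained weights loaded correctly before you move on to fine-tuning or feature extraction.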