Search code examples
Tags: python, pytorch, nix, nixos, pytorch-geometric

How can I get torch-geometric to work using Nix?


I am trying to get the Python package torch-geometric to work using Nix (I am on NixOS). Currently, I use mach-nix to try to set up a Python environment. However, the difficulty is that some of the dependencies must be downloaded from a separate file server (not PyPI), namely https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html. As a first step, I am trying to set up an environment containing a single torch-geometric dependency: torch-sparse.

Currently I have the following shell.nix:

# shell.nix (question): builds torch_sparse 0.6.9 from its GitHub sources with
# mach-nix and tries to expose it in a Python 3.8 environment.
# NOTE(review): this is the failing attempt -- `nix-shell` stops during
# build_ext with "[Errno 2] No such file or directory: 'which'".
{ pkgs ? import <nixpkgs> {} }:

let
  # mach-nix pinned to release tag 3.3.0 and configured for Python 3.8.
  mach-nix = import (builtins.fetchGit {
    url = "https://github.com/DavHau/mach-nix/";
    ref = "refs/tags/3.3.0";
  }) {
    python = "python38";
  };
  # torch_sparse built from source (git tag 0.6.9) instead of from the
  # prebuilt wheels on pytorch-geometric.com; the source build is the step
  # that fails.
  sparse = mach-nix.buildPythonPackage {
    pname = "torch_sparse";
    version = "0.6.9";
    # Requirements mach-nix resolves for this build; the pytest* entries
    # presumably mirror the project's setup/test requirements -- verify.
    requirements = ''
      torch
      scipy
      pytest
      pytest-cov
      pytest-runner
    '';
    src = builtins.fetchGit {
      url = "https://github.com/rusty1s/pytorch_sparse";
      ref = "refs/tags/0.6.9";
    };
  };
in mach-nix.mkPython {
  # NOTE(review): "torch-sparse" is requested here by name as well as via the
  # locally built `sparse` in packagesExtra -- confirm which one mach-nix uses.
  requirements = "torch-sparse";
  packagesExtra = [
    sparse
  ];
}

Running nix-shell with this file fails with the following error message:

running build_ext
error: [Errno 2] No such file or directory: 'which'
builder for '/nix/store/fs9nrrd2a233xp5d6njy6639yjbxp4g0-python3.8-torch_sparse-0.6.9.drv' failed with exit code 1

I tried adding the `which` package to either checkInputs or buildInputs, but that does not solve the problem. Note that I build the package directly from its GitHub repository, because I am unsure how to reference a wheel package in mach-nix. I am relatively new to NixOS and, quite frankly, I am completely lost.

How should I go about installing a Python package such as torch-sparse or torch-geometric? Am I even using the correct tools?


Solution

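  • I have managed to come up with a working Nix expression, which I will leave here for future reference. Running nix-shell on the following expression creates a shell with torch 1.8.0 and torch-geometric 1.7.0 plus their required dependencies, all installed from prebuilt wheels for Python 3.8 on x86_64 Linux. Note that the torch wheel is the CUDA 11.1 build, whereas the torch-sparse, torch-scatter, torch-cluster and torch-spline-conv wheels are the CPU builds from the torch-1.8.0+cpu index.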

    { pkgs ? import <nixpkgs> { } }:

    # Development shell with torch 1.8.0 and torch-geometric 1.7.0 (plus the
    # compiled extensions torch_sparse, torch_scatter, torch_cluster and
    # torch_spline_conv), all installed from prebuilt wheels.
    #
    # The wheel hashes below are pinned to the Python 3.8 / linux_x86_64 wheels:
    # changing `python` or the platform also requires updating every hash.
    let
      python = pkgs.python38;

      # CPython wheel tag, e.g. "cp38" for Python 3.8, used for both the
      # interpreter and ABI parts of the wheel file name. This matches
      # CPython >= 3.8 (older versions use an "m"-suffixed ABI tag, e.g. "cp37m").
      pyTag = "cp" + builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion;

      # Every wheel used here is a linux_x86_64 build. On other systems,
      # evaluation fails with a clear message as soon as a wheel name is forced.
      wheelPlatform =
        if pkgs.stdenv.hostPlatform.system == "x86_64-linux" then
          "linux_x86_64"
        else
          throw "Unsupported system: ${pkgs.stdenv.hostPlatform.system}";

      # Wheel file name for `pname`-`version` matching `python` and the platform.
      wheelName = pname: version:
        "${pname}-${version}-${pyTag}-${pyTag}-${wheelPlatform}.whl";

      # torch 1.8.0, CUDA 11.1 build, from the official PyTorch wheel index.
      # NOTE(review): the PyG extension wheels below come from the
      # `torch-1.8.0+cpu` index. For a CPU-only setup consider the `+cpu` torch
      # wheel; for GPU use the extensions presumably need the
      # `torch-1.8.0+cu111` index instead (either way the hashes change) -- confirm.
      pytorch-180 = let version = "1.8.0";
      in python.pkgs.buildPythonPackage {
        inherit version;
        pname = "pytorch";
        format = "wheel";

        src = pkgs.fetchurl {
          name = wheelName "torch" version;
          url = "https://download.pytorch.org/whl/cu111/${wheelName "torch" "${version}%2Bcu111"}";
          hash = "sha256-4NYiAkYfGXm3orLT8Y5diepRMAg+WzJelncy2zJp+Ho=";
        };

        nativeBuildInputs = with pkgs; [ addOpenGLRunpath patchelf ];

        propagatedBuildInputs = with python.pkgs; [
          future
          numpy
          pyyaml
          requests
          typing-extensions
        ];

        # Remove the wheel's bin/ directory (ONNX conversion entry points).
        postInstall = ''
          rm -rf $out/bin
        '';

        # Patch the RPATH of every bundled shared library so it finds libstdc++
        # (stdenv.cc.cc.lib) and torch's own lib directory, then run
        # addOpenGLRunpath on it so the NixOS driver library path is added too.
        postFixup = let rpath = pkgs.lib.makeLibraryPath [ pkgs.stdenv.cc.cc.lib ];
        in ''
          find $out/${python.sitePackages}/torch/lib -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
            echo "setting rpath for $lib..."
            patchelf --set-rpath "${rpath}:$out/${python.sitePackages}/torch/lib" "$lib"
            addOpenGLRunpath "$lib"
          done
        '';

        pythonImportsCheck = [ "torch" ];

        meta = with pkgs.lib; {
          description =
            "Open source, prototype-to-production deep learning platform";
          homepage = "https://pytorch.org/";
          changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
          license = licenses.unfree; # Includes CUDA and Intel MKL.
          platforms = platforms.linux;
          maintainers = with maintainers; [ danieldk ];
        };
      };

      # Index serving the PyG extension wheels built against torch 1.8.0 (CPU).
      pygWheelIndex = "https://pytorch-geometric.com/whl/torch-1.8.0+cpu";

      # Build one prebuilt PyG extension wheel from `pygWheelIndex`.
      #   pname, version: wheel name components (e.g. "torch_sparse", "0.6.9")
      #   hash:           SRI hash of the wheel file
      #   extraDeps:      Python deps propagated in addition to torch
      # Checks are disabled. A top-level `test` directory in site-packages is
      # removed, presumably so the wheels do not collide in the combined env.
      buildPygWheel = { pname, version, hash, extraDeps ? [ ] }:
        python.pkgs.buildPythonPackage {
          inherit pname version;
          format = "wheel";

          src = pkgs.fetchurl {
            name = wheelName pname version;
            url = "${pygWheelIndex}/${wheelName pname version}";
            inherit hash;
          };

          propagatedBuildInputs = [ pytorch-180 ] ++ extraDeps;

          doCheck = false;

          postInstall = ''
            rm -rf $out/${python.sitePackages}/test
          '';
        };

      sparse = buildPygWheel {
        pname = "torch_sparse";
        version = "0.6.9";
        hash = "sha256-6dmZNQ0FlwKdfESKhvv8PPwzgsJFWlP8tYXWu2JLiMk=";
        extraDeps = [ python.pkgs.scipy ];
      };

      scatter = buildPygWheel {
        pname = "torch_scatter";
        version = "2.0.7";
        hash = "sha256-MRoFretgyEpq+7aJZc0399Kd+f28Uhn5+CxW5ZIKwcg=";
      };

      cluster = buildPygWheel {
        pname = "torch_cluster";
        version = "1.5.9";
        hash = "sha256-E2nywtiZ7m7VA1J7AY7gAHYvyN9H3zl/W0/WsZLzwF8=";
      };

      spline = buildPygWheel {
        pname = "torch_spline_conv";
        version = "1.2.1";
        hash = "sha256-ghSzoxoqSccPAZzfcHJEPYySQ/KYqQ90mFsOdt1CjUw=";
      };

      # Pure-Python dependency of torch_geometric, built from the PyPI sdist.
      python-louvain = with python.pkgs;
        buildPythonPackage rec {
          pname = "python-louvain";
          version = "0.15";

          src = fetchPypi {
            inherit pname version;
            sha256 = "1sqp97fwh4asx0jr72x8hil8z8fcg2xq92jklmh2m599pvgnx19a";
          };

          propagatedBuildInputs = [ numpy networkx ];

          doCheck = false;
        };

      # Pure-Python dependency of torch_geometric, built from the PyPI sdist.
      googledrivedownloader = with python.pkgs;
        buildPythonPackage rec {
          pname = "googledrivedownloader";
          version = "0.4";

          src = fetchPypi {
            inherit pname version;
            sha256 = "0172l1f8ys0913wcr16lzx87vsnapppih62qswmvzwrggcrw2d2b";
          };

          doCheck = false;
        };

      # torch_geometric itself, from its PyPI sdist. Inside `with python.pkgs`,
      # the let-bound names above (pytorch-180, python-louvain,
      # googledrivedownloader) take precedence over same-named set attributes.
      geometric = with python.pkgs;
        buildPythonPackage rec {
          pname = "torch_geometric";
          version = "1.7.0";

          src = fetchPypi {
            inherit pname version;
            sha256 = "1a7ym34ynhk5gb3yc5v4qkmkrkyjbv1fgisrsk0c9xay66w7nwz9";
          };

          propagatedBuildInputs = [
            pytorch-180
            numpy
            scipy
            tqdm
            networkx
            scikit-learn
            requests
            pandas
            rdflib
            jinja2
            numba
            ase
            h5py
            python-louvain
            googledrivedownloader
          ];
          nativeBuildInputs = [ pytest-runner ];

          doCheck = false;
        };

      # Interpreter with all packages above. `withPackages` expects a list of
      # Python package derivations; the package-set argument is not needed
      # because every entry is a local binding.
      python-with-pkgs = python.withPackages
        (_: [ pytorch-180 scatter sparse cluster spline geometric ]);
    in pkgs.mkShell { buildInputs = [ python-with-pkgs ]; }