Tags: python, scikit-learn, fb-hydra

Define an argument for a Hydra instantiate target as a power of 2


When using Hydra's instantiation, is there a way to define a parameter for a target as a power of 2? For example:

from sklearn.feature_extraction.text import HashingVectorizer

vec = HashingVectorizer(n_features=2**18)
vec.transform(["a quick fox"])
<1x262144 sparse matrix of type '<class 'numpy.float64'>'
        with 2 stored elements in Compressed Sparse Row format>

As expected, the output is a sparse matrix of shape (1, 262144), and 262144 == 2**18.

However, in a config file you cannot write the value as 2**18, because it is parsed and passed to the target as a string.

config.yaml

vec:
  _target_: sklearn.feature_extraction.text.HashingVectorizer
  n_features: 2**18

test.py

import hydra
import hydra.utils as hu


@hydra.main(config_path='conf', config_name='config')
def main(cfg):
    vec = hu.instantiate(cfg.vec)
    vec.transform(['Erroneous Monk'])


if __name__ == "__main__":
    main()

Running this example produces the following error:

python test.py
...
TypeError: n_features must be integral, got '2**18' (<class 'str'>).

Is there a way to tell Hydra that the value should not be treated as a string?


Solution

  • Arithmetic expressions are not currently supported in OmegaConf (the underlying configuration library), but you can get the same effect with a custom resolver. For example, you could register a custom resolver named pow that applies Python's power operator (x**y) to its two arguments:

    import hydra
    import hydra.utils as hu
    from omegaconf import OmegaConf
    
    # Register the resolver before the config field is accessed.
    OmegaConf.register_new_resolver("pow", lambda x, y: x**y)
    
    @hydra.main(config_path='conf', config_name='config')
    def main(cfg):
        vec = hu.instantiate(cfg.vec)
        vec.transform(['Erroneous Monk'])
    
    
    if __name__ == "__main__":
        main()
    

    Your config can be defined as:

    vec:
      _target_: sklearn.feature_extraction.text.HashingVectorizer 
      n_features: ${pow:2,18}