
Use of Params in PySpark


In this example, I am trying to use overrides as a Param and I want it to hold a list of strings.

But I am not able to assign its value using the code below.

from typing import List

from pyspark import keyword_only
from pyspark.ml.param import Param, Params, TypeConverters


class _AB(Params):

    overrides = Param(Params._dummy(), "overrides", "Parameters for environment setup", typeConverter=TypeConverters.toListString)

    def __init__(self, *args):
        super().__init__(*args)
        self._setDefault(overrides=None)


class A(_AB):
    @keyword_only
    def __init__(self, overrides):
        super().__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, overrides: List[str]):
        kwargs = self._input_kwargs
        print(kwargs)
        return self._set(**kwargs)

    def c(self):
        print(self.overrides.__dict__['typeConverter'].__dict__)
        # raises TypeError: 'Param' object is not iterable
        for i in self.overrides:
            print(i)


a = A(overrides=["dsfs", "Sdf"])
a.c()

Printing the type converter's __dict__ inside c gives an empty dictionary, and the loop then fails with:

TypeError: 'Param' object is not iterable

I guess this happens because no value is being assigned to the overrides variable.


Solution

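  • In PySpark, a Param attribute such as self.overrides is only the parameter's definition (its parent, name, doc and type converter); the value passed to _set is stored separately in the instance's param map. The value therefore has to be read through a getter that calls self.getOrDefault(self.overrides). Iterating over the Param itself raises the TypeError above.

    For example: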

    from typing import List

    from pyspark import keyword_only
    from pyspark.ml.param import Param, Params, TypeConverters


    class _AB(Params):

        overrides = Param(Params._dummy(), "overrides", "Overrides parameters for environment setup", typeConverter=TypeConverters.toListString)

        def __init__(self, *args):
            super().__init__(*args)
            self._setDefault(overrides=None)

        def getOverrides(self):
            # Returns the user-set value, or the default if none was set
            return self.getOrDefault(self.overrides)


    class A(_AB):
        @keyword_only
        def __init__(self, overrides):
            super().__init__()
            kwargs = self._input_kwargs
            self.setParams(**kwargs)

        @keyword_only
        def setParams(self, overrides: List[str]):
            kwargs = self._input_kwargs
            return self._set(**kwargs)

        def c(self):
            # Read the value through the getter, not the Param object itself
            overrides = self.getOverrides()
            for i in overrides:
                print(i)


    a = A(overrides=["Alpha", "Beta"])
    a.c()
    

    It gives:

    Alpha
    Beta