
Cannot call child-class methods in a strategy pattern


I am stuck on a strategy pattern implementation.

I would like the methods of the child classes that implement the strategy pattern to be called, but it seems that only the abstract class's method is called.

from abc import abstractmethod

class TranslationStrategy:

    @classmethod
    @abstractmethod
    def translate_in_french(cls, text: str) -> str:
        pass

    @classmethod
    @abstractmethod
    def translate_in_spanish(cls, text: str) -> str:
        pass

    FRENCH = translate_in_french
    SPANISH = translate_in_spanish


class Translator(TranslationStrategy):

    @abstractmethod
    def __init__(self, strategy = TranslationStrategy.FRENCH):
        self.strategy = strategy

    def get_translation(self):
        print(self.strategy("random string"))

class LiteraryTranslator(Translator):

    def __init__(self, strategy = TranslationStrategy.FRENCH):
        super().__init__(strategy)

    def translate_in_french(self, text: str) -> str:
        return "french_literary_translation"

    def translate_in_spanish(self, text: str) -> str:
        return "spanish_literary_translation"


class TechnicalTranslator(Translator):

    def __init__(self, strategy=TranslationStrategy.FRENCH):
        super().__init__(strategy)

    def translate_in_french(self, text: str) -> str:
        return "french_technical_translation"

    def translate_in_spanish(self, text: str) -> str:
        return "spanish_technical_translation"

translator = TechnicalTranslator(TranslationStrategy.FRENCH)
translator.get_translation() # prints None, I expect "french_technical_translation"

Am I misusing the strategy pattern here?


Solution

  • I am not familiar with the strategy pattern, but to make your code work, you could use something like the following:

    from abc import ABC, abstractmethod
    from enum import Enum
    
    
    class TranslationStrategy(str, Enum):
        FRENCH = 'french'
        SPANISH = 'spanish'
        
        @classmethod
        def _missing_(cls, value):
            if isinstance(value, str):
                try:
                    return cls._member_map_[value.upper()]
                except KeyError:
                    pass
            return super()._missing_(value)
    
    
    class Translator(ABC):
        def __init__(self, strategy=TranslationStrategy.FRENCH):
            self._strategy = TranslationStrategy(strategy)
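            # Use .value explicitly: formatting the member itself is version-dependent
            # (newer Python versions yield 'TranslationStrategy.FRENCH', not 'french').
            self.strategy = getattr(self, f'translate_in_{self._strategy.value}')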
    
        @abstractmethod
        def translate_in_french(self, text: str) -> str:
            pass
    
        @abstractmethod
        def translate_in_spanish(self, text: str) -> str:
            pass
        
        def get_translation(self, text: str = 'random string'):
            print(self.strategy(text))
    
    
    class LiteraryTranslator(Translator):
        def __init__(self, strategy=TranslationStrategy.FRENCH):
            super().__init__(strategy)
    
        def translate_in_french(self, text: str) -> str:
            return "french_literary_translation"
    
        def translate_in_spanish(self, text: str) -> str:
            return "spanish_literary_translation"
    
    
    class TechnicalTranslator(Translator):
        def __init__(self, strategy=TranslationStrategy.FRENCH):
            super().__init__(strategy)
    
        def translate_in_french(self, text: str) -> str:
            return "french_technical_translation"
    
        def translate_in_spanish(self, text: str) -> str:
            return "spanish_technical_translation"
    

    Note that the __init__ method is NOT marked as abstract here. It should not be, and more generally any method whose implementation you rely on should not be marked as abstract.
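
To illustrate the point about abstract methods, here is a minimal sketch. The class names Base, Incomplete and Complete are hypothetical and not part of the question's code. It shows that a class deriving from ABC can only be instantiated once every @abstractmethod is overridden, while a concrete __init__ in the base class is simply inherited.

```python
from abc import ABC, abstractmethod


class Base(ABC):
    # Concrete __init__: subclasses rely on this implementation, so it is not abstract.
    def __init__(self, name: str = "base"):
        self.name = name

    @abstractmethod
    def translate(self, text: str) -> str:
        """Subclasses must provide this."""


class Incomplete(Base):
    pass  # does not override translate()


class Complete(Base):
    def translate(self, text: str) -> str:
        return f"{self.name}: {text}"


try:
    Incomplete()
    instantiated = True
except TypeError:
    # ABC refuses to instantiate a class that still has abstract methods.
    instantiated = False

result = Complete("demo").translate("hello")
print(instantiated, result)
```

Note that @abstractmethod is only enforced when the class uses ABCMeta as its metaclass, for example by deriving from ABC. The TranslationStrategy class in the question did not, so its abstract markers had no effect.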

    The self._strategy = TranslationStrategy(strategy) line will ensure that the given strategy is a member of that enum. That is, it will automatically normalize input, and reject invalid values:

    >>> TranslationStrategy('French')
    <TranslationStrategy.FRENCH: 'french'>
    
    >>> TranslationStrategy('french')
    <TranslationStrategy.FRENCH: 'french'>
    
    >>> TranslationStrategy('FRENCH')
    <TranslationStrategy.FRENCH: 'french'>
    
    >>> TranslationStrategy(TranslationStrategy.FRENCH)
    <TranslationStrategy.FRENCH: 'french'>
    
    >>> TranslationStrategy('foo')
    Traceback (most recent call last):
    ...
    ValueError: 'foo' is not a valid TranslationStrategy
    

    To call a subclass's method, the reference to it has to be obtained once the concrete subclass is known, that is, on the instance. The getattr call assigned to self.strategy looks up translate_in_french or translate_in_spanish on the current object, so it resolves to the method defined in the class you instantiate. Your approach did not work because FRENCH and SPANISH were assigned inside the body of TranslationStrategy, so they stored references to the abstract TranslationStrategy.translate_in_french and TranslationStrategy.translate_in_spanish methods, not to the ones defined in the subclass.
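
The difference between the two lookups can be reproduced in a few lines. This is a reduced sketch with hypothetical names (Base, Child, greet, GREET), not the question's classes: an alias created in the base class body keeps pointing at the base implementation, while looking the name up on the instance finds the override.

```python
class Base:
    @classmethod
    def greet(cls) -> str:
        return "base"

    # Evaluated once, while Base's body runs: this names Base's own greet object.
    GREET = greet


class Child(Base):
    def greet(self) -> str:
        return "child"


child = Child()

# The alias still refers to the method defined in Base, even via a Child instance.
via_alias = child.GREET()

# Looking the name up on the instance finds Child's override.
via_getattr = getattr(child, "greet")()

print(via_alias, via_getattr)
```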

    The __init__ implementations in LiteraryTranslator and TechnicalTranslator are not strictly necessary here, since they only call super().__init__(...) with the same arguments as the parent class.
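
As a sketch of that simplification, the subclass below defines no __init__ at all and inherits the parent's. The Translator here is a simplified stand-in that takes plain strings instead of the enum, which is an assumption made to keep the example short.

```python
class Translator:
    """Simplified stand-in for the Translator above (plain strings, no enum)."""

    def __init__(self, strategy: str = "french"):
        # Resolved on the instance, so a subclass's override is picked up.
        self.strategy = getattr(self, f"translate_in_{strategy}")

    def translate_in_french(self, text: str) -> str:
        raise NotImplementedError

    def translate_in_spanish(self, text: str) -> str:
        raise NotImplementedError


class TechnicalTranslator(Translator):
    # No __init__ here: Translator.__init__ is inherited unchanged.
    def translate_in_french(self, text: str) -> str:
        return "french_technical_translation"

    def translate_in_spanish(self, text: str) -> str:
        return "spanish_technical_translation"


default_result = TechnicalTranslator().strategy("random string")
spanish_result = TechnicalTranslator("spanish").strategy("random string")
print(default_result, spanish_result)
```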

    Lastly, stacking @classmethod with @abstractmethod results in an abstract class method, not an abstract instance (normal) method. Since these are intended to be normal methods, @classmethod had to be removed.
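
The two kinds of abstract method can be put side by side. In this hedged sketch the names Base, Impl, describe and translate are hypothetical: the stacked decorators declare a method that implementations receive the class for, while a plain @abstractmethod declares one that needs an instance.

```python
from abc import ABC, abstractmethod


class Base(ABC):
    @classmethod
    @abstractmethod
    def describe(cls) -> str:
        """Abstract class method: implementations receive the class."""

    @abstractmethod
    def translate(self, text: str) -> str:
        """Abstract instance method: implementations receive the instance."""


class Impl(Base):
    prefix = "impl"

    @classmethod
    def describe(cls) -> str:
        return cls.__name__

    def translate(self, text: str) -> str:
        return f"{self.prefix}:{text}"


class_level = Impl.describe()            # callable without an instance
instance_level = Impl().translate("x")   # needs an instance
print(class_level, instance_level)
```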