python, mypy, python-typing

mypy arg should be one of a set of functions


from enum import Enum

def strategy_a(x: float, y: float) -> float:
    return x + y

def strategy_b(x: float, y: float) -> float:
    return x * y

class Strategy(Enum):
    A = strategy_a
    B = strategy_b

def run_strategy(x: float, y: float, strategy: Strategy) -> float:
    return strategy(x, y)

Say I have something like this, where the strategy argument of run_strategy should accept only one of a fixed set of functions. How can I annotate it so that only those functions can be passed, without mypy reporting an error?

Note: mypy rejects the code above, complaining that Strategy is not callable.

The function is then called like this:

run_strategy(5, 17, Strategy.A)
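A quick runtime check (plain CPython, no mypy involved) shows a related wrinkle. Functions are descriptors, so Enum treats strategy_a and strategy_b as methods rather than members. Strategy therefore has no members at all, and Strategy.A is simply the function strategy_a. That is why the call above happens to work at runtime even though the value passed is not a Strategy member, which is exactly what the annotation promises.

```python
from enum import Enum


def strategy_a(x: float, y: float) -> float:
    return x + y


def strategy_b(x: float, y: float) -> float:
    return x * y


class Strategy(Enum):
    # Functions are descriptors, so Enum keeps them as methods, not members.
    A = strategy_a
    B = strategy_b


print(len(Strategy))             # 0: no members were created
print(Strategy.A is strategy_a)  # True: the attribute is the plain function
print(Strategy.A(5, 17))         # 22
```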

Solution

  • A workaround is to define a protocol, Strategy, that requires a __call__ method, and to make the Enum class inherit from that protocol, as in this example, blah.py:

    from enum import Enum
    from typing import Protocol
    
    
    def strategy_a(x: float, y: float) -> float:
        return x + y
    
    
    def strategy_b(x: float, y: float) -> float:
        return x * y
    
    
    class Strategy(Protocol):
        def __call__(self, x: float, y: float) -> float:
            ...
    
    
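    # Caveat: this passes mypy, but at runtime Python raises TypeError here
    # (metaclass conflict between Enum's metaclass and Protocol's), and
    # functions assigned in an Enum body become methods, not members.
    class StrategyChoice(Enum, Strategy):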
        A = strategy_a
        B = strategy_b
    
    
    def run_strategy(x: float, y: float, strategy: StrategyChoice) -> float:
        return strategy(x, y)
    
    mypy ./blah.py
    Success: no issues found in 1 source file
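A clean mypy run only covers the static check; Python itself cannot create a class that inherits from both Enum and a Protocol, because their metaclasses are incompatible. A minimal sketch that both type-checks and runs, assuming Python 3.9+ (for dict[...] and collections.abc.Callable subscripting), keeps the members as plain string values and gives the Enum its own __call__, which looks the function up in a module-level table. The member values "a"/"b" and the name _IMPLEMENTATIONS are illustrative choices, not part of the original answer.

```python
from collections.abc import Callable
from enum import Enum


def strategy_a(x: float, y: float) -> float:
    return x + y


def strategy_b(x: float, y: float) -> float:
    return x * y


class StrategyChoice(Enum):
    # Plain string values, so A and B really are enum members.
    A = "a"
    B = "b"

    def __call__(self, x: float, y: float) -> float:
        # Delegate to the function registered for this member.
        return _IMPLEMENTATIONS[self](x, y)


# Illustrative lookup table mapping each member to its implementation.
_IMPLEMENTATIONS: dict[StrategyChoice, Callable[[float, float], float]] = {
    StrategyChoice.A: strategy_a,
    StrategyChoice.B: strategy_b,
}


def run_strategy(x: float, y: float, strategy: StrategyChoice) -> float:
    # mypy sees StrategyChoice.__call__, so calling a member type-checks.
    return strategy(x, y)


print(run_strategy(5, 17, StrategyChoice.A))  # 22
```

Because each member now has a matching __call__, the members also satisfy the Strategy protocol structurally, with no inheritance needed, while passing a bare function such as strategy_a to run_strategy is still rejected by mypy, since it is not a StrategyChoice.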