Tags: python, python-typing, pyright

pyright prioritizes the number of passed arguments over their type when using @overload


I have a base class _BaseThing, as well as its subclasses CubeThing and BallThing. The _BaseThing class should never be instantiated directly in the program, so it is prefixed with _.

I also have a factory class Thing that creates an instance of a specific class (CubeThing or BallThing) depending on the passed thing_type.

It is important that, in addition to the common parameters, each Thing implementation has its own unique parameters (the cube has an edge length; the ball has a radius and some_other_unique_param).

from __future__ import annotations
from typing import overload, Literal
from enum import Enum


class ThingType(Enum):
    CUBE = 0
    BALL = 1


class _BaseThing:
    """
    Base class for all Things.
    It shouldn't be instantiated in the real program, so it is prefixed with _.
    """
    def __init__(self, thing_type: ThingType, color, material):
        self.type = thing_type
        self.color = color
        self.material = material

    def change_color(self, color): ...

    def change_material(self, material): ...


class CubeThing(_BaseThing):
    """
    Cubical implementation of the Thing.
    """
    def __init__(self, color, material, edge_length: float):
        super().__init__(ThingType.CUBE, color, material)
        self.edge_length = edge_length

    def some_cube_method(self):
        print('This is cube!')


class BallThing(_BaseThing):
    """
    Spherical implementation of the Thing.
    """
    def __init__(self, color, material, radius: float, some_other_unique_param: int):
        super().__init__(ThingType.BALL, color, material)
        self.radius = radius
        self.some_other_unique_param = some_other_unique_param

    def some_ball_method(self):
        print('This is ball!')


class Thing:
    """Class that looks like the base class for all Things (because of its name), but is actually a 'factory'."""
    __thing_types_dict__ = {
        ThingType.CUBE: CubeThing,
        ThingType.BALL: BallThing,
    }
    @overload
    def __new__(cls, thing_type: Literal[ThingType.CUBE], color, material, edge_length: float) -> CubeThing:
        ...

    @overload
    def __new__(cls, thing_type: Literal[ThingType.BALL], color, material,
                radius: float, some_other_unique_param: int) -> BallThing:
        ...

    def __new__(cls, thing_type: ThingType, color, material, *args, **kwargs):
        if thing_type not in cls.__thing_types_dict__:
            raise TypeError('Unexpected thing_type.')
        
        return cls.__thing_types_dict__[thing_type](color, material, *args, **kwargs)

When I call ball = Thing(ThingType.BALL, 'blue', 'silicon', 1), the PyCharm static analyzer (as well as pyright) warns me that I passed an invalid thing_type argument (importantly, in this example I omitted the last argument). That is, pyright has already "selected" the @overload that returns CubeThing. However, as soon as I add the last argument, it immediately “switches” to the @overload that returns BallThing and the warning disappears. So, PyCharm/pyright does not prioritize the types I passed, but the number of arguments passed.

However, I'd like the static analyzer to identify the right @overload as soon as I enter the first argument (thing_type), and to base all further warnings on it. In the example above, I want it to report not that I passed an invalid type, but that I passed too few arguments.
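For contrast, the trimmed sketch below (the classes are cut down to their constructors, and THING_CLASSES is a hypothetical module-level stand-in for the thing_types_dict mapping) shows that at runtime the dispatch really is keyed on thing_type first. The same call therefore fails with exactly the "missing argument" error the question wants the static checker to report.

```python
from enum import Enum


class ThingType(Enum):
    CUBE = 0
    BALL = 1


class CubeThing:
    def __init__(self, color, material, edge_length: float):
        self.color, self.material = color, material
        self.edge_length = edge_length


class BallThing:
    def __init__(self, color, material, radius: float, some_other_unique_param: int):
        self.color, self.material = color, material
        self.radius = radius
        self.some_other_unique_param = some_other_unique_param


# Hypothetical stand-in for the thing_types_dict mapping in the question.
THING_CLASSES = {ThingType.CUBE: CubeThing, ThingType.BALL: BallThing}


def create_thing(thing_type: ThingType, color, material, *args, **kwargs):
    # At runtime the concrete class is chosen from thing_type alone...
    cls = THING_CLASSES[thing_type]
    # ...so a wrong argument count is reported by that class's __init__.
    return cls(color, material, *args, **kwargs)


try:
    create_thing(ThingType.BALL, "blue", "silicon", 1)
except TypeError as exc:
    # The message names the missing 'some_other_unique_param'.
    print(exc)
```

The exact wording of the message varies between Python versions, but it always names the missing parameter of BallThing.__init__, because the runtime never considers CubeThing for this call.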

I thought that maybe the problem was that I was overloading the __new__ method, so I tried a function-based implementation:

@overload
def create_thing(thing_type: Literal[ThingType.CUBE], color, material, edge_length: float) -> CubeThing:
    ...

@overload
def create_thing(thing_type: Literal[ThingType.BALL], color, material, 
                 radius: float, some_other_unique_param: int) -> BallThing:
    ...


def create_thing(thing_type: ThingType, color, material, *args, **kwargs):
    thing_types_dict = {
        ThingType.CUBE: CubeThing,
        ThingType.BALL: BallThing,
    }

    if thing_type not in thing_types_dict:
        raise TypeError('Unexpected thing_type.')

    return thing_types_dict[thing_type](color, material, *args, **kwargs)

However, this variant has exactly the same problem.

I also tried using the strings '0' and '1' in the Literal arguments and type hints instead of ThingType, but that didn't help either.


Solution

  • So, PyCharm/pyright does not prioritize the types I passed, but the number of arguments passed.

    That is correct. Pyright looks at the number of arguments (as well as the existence of keyword arguments) to filter the overloads first and foremost. PyCharm probably has a similar strategy.

    To quote Pyright's documentation:

    PEP 484 introduced the @overload decorator and described how it can be used, but the PEP did not specify precisely how a type checker should choose the “best” overload. Pyright uses the following rules.

    1. Pyright first filters the list of overloads based on simple “arity” (number of arguments) and keyword argument matching. [...]
    2. Pyright next considers the types of the arguments and compares them to the declared types of the corresponding parameters. [...]
    3. If only one overload remains, it is the “winner”.
    4. If more than one overload remains, the “winner” is chosen based on the order in which the overloads are declared. In general, the first remaining overload is the “winner”. [...]

    Static Typing: Advanced Topics § Overloads | Pyright's documentation

    This strategy does not always produce the "optimal" diagnostic when no overload matches, but, as explained, it was chosen in the absence of a standardized strategy.

    (It should be emphasized that having a standardized strategy does not mean the outputs will be "optimal", just that all tools will fail in the same way.)

    As it happens, there are ongoing efforts to standardize this.
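The first two quoted rules can be reproduced with a rough, hypothetical model to see why the question's call is checked against the CubeThing overload. This is not pyright's implementation: the functions resolve, arity_ok, types_ok and value_matches, and the stand-ins cube_overload and ball_overload, are illustrative names; inspect.signature(...).bind approximates arity and keyword matching; and the "type check" only understands Literal, float and plain classes.

```python
import inspect
from enum import Enum
from typing import Callable, List, Literal, get_args, get_origin


class ThingType(Enum):
    CUBE = 0
    BALL = 1


# Stand-ins for the two @overload signatures (their bodies never run).
def cube_overload(thing_type: Literal[ThingType.CUBE], color, material,
                  edge_length: float): ...


def ball_overload(thing_type: Literal[ThingType.BALL], color, material,
                  radius: float, some_other_unique_param: int): ...


OVERLOADS: List[Callable] = [cube_overload, ball_overload]  # declaration order


def value_matches(value, annotation) -> bool:
    """Crude stand-in for a type check (Literal, float and plain classes only)."""
    if annotation is inspect.Parameter.empty:
        return True  # unannotated parameters accept anything
    if get_origin(annotation) is Literal:
        return value in get_args(annotation)
    if annotation is float:
        return isinstance(value, (int, float))  # int is accepted where float is expected
    if isinstance(annotation, type):
        return isinstance(value, annotation)
    return True


def arity_ok(overload: Callable, args: tuple, kwargs: dict) -> bool:
    """Rule 1: do the argument count and keyword names fit this overload?"""
    try:
        inspect.signature(overload).bind(*args, **kwargs)
        return True
    except TypeError:
        return False


def types_ok(overload: Callable, args: tuple, kwargs: dict) -> bool:
    """Rule 2: does every argument value fit its parameter annotation?"""
    sig = inspect.signature(overload)
    bound = sig.bind(*args, **kwargs)
    return all(value_matches(value, sig.parameters[name].annotation)
               for name, value in bound.arguments.items())


def resolve(*args, **kwargs):
    """Return (overloads left after rule 1, overloads left after rule 2)."""
    after_arity = [ov for ov in OVERLOADS if arity_ok(ov, args, kwargs)]
    after_types = [ov for ov in after_arity if types_ok(ov, args, kwargs)]
    return after_arity, after_types


if __name__ == "__main__":
    # The call from the question: four positional arguments.
    after_arity, after_types = resolve(ThingType.BALL, "blue", "silicon", 1)
    print([f.__name__ for f in after_arity])  # ['cube_overload']
    print([f.__name__ for f in after_types])  # []
```

With four arguments, the ball overload (which needs five) is discarded by the arity step before the value of thing_type is ever considered, so the only remaining candidate is the cube overload, and its Literal[ThingType.CUBE] parameter is what rejects ThingType.BALL. That matches the "invalid thing_type" warning described in the question. Adding the fifth argument flips the arity step, leaving only ball_overload, which then passes the type step.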