Tags: python, python-typing, pyright

How to overload a class member's type depending on an argument?


I recently learned that I can overload a function's return type using Literal[True] and Literal[False] parameter types. I am implementing my own subprocess.Popen-like interface, and I cannot get self.stdin to be typed as IO[bytes] or IO[str] depending on the value of text passed to the constructor. I am using pyright for static type checking.
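For reference, the return-type overloading pattern mentioned above looks roughly like this for a plain function. `read_payload` is a hypothetical function used only for illustration; the return type follows the literal value of the `text` flag.

```python
from typing import Literal, Union, overload


# Hypothetical example: the declared return type follows the `text` flag.
@overload
def read_payload(text: Literal[True]) -> str: ...
@overload
def read_payload(text: Literal[False] = ...) -> bytes: ...
def read_payload(text: bool = False) -> Union[str, bytes]:
    payload = b"hello"
    # Decode only when the caller asked for text.
    return payload.decode("ascii") if text else payload


s = read_payload(text=True)  # a checker infers str
b = read_payload()           # a checker infers bytes
```

This works because each call is matched against one overload signature. An attribute like `self.stdin` has no signature of its own, which is why the same trick does not carry over directly.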

I have tried the following:

from typing import IO, Literal, Optional, overload


class MyPopen:
    @overload
    def __init__(self, text: Literal[False] = False):
        self.stdin: Optional[IO[bytes]]

    @overload
    def __init__(self, text: Literal[True] = True):
        self.stdin: Optional[IO[str]]

    def __init__(self, text: bool = False):
        self.stdin = None


pp = MyPopen(text=True)
assert pp.stdin
pp.stdin.write("text")    # should be ok
pp.stdin.write(b"text")   # should error
pp = MyPopen(text=False)
assert pp.stdin
pp.stdin.write("text")    # should error
pp.stdin.write(b"text")   # should be ok

However, pyright infers pp.stdin as IO[str] in both cases, and it reports that the declaration of stdin is obscured:

$ pyright /dev/stdin <1.py
/dev/stdin
  /dev/stdin:7:14 - error: Declaration "stdin" is obscured by a declaration of the same name (reportRedeclaration)
  /dev/stdin:20:1 - error: No overloads for "write" match the provided arguments (reportCallIssue)
  /dev/stdin:20:16 - error: Argument of type "Literal[b"text"]" cannot be assigned to parameter "__s" of type "str" in function "write"
    "Literal[b"text"]" is incompatible with "str" (reportArgumentType)
  /dev/stdin:24:1 - error: No overloads for "write" match the provided arguments (reportCallIssue)
  /dev/stdin:24:16 - error: Argument of type "Literal[b"text"]" cannot be assigned to parameter "__s" of type "str" in function "write"
    "Literal[b"text"]" is incompatible with "str" (reportArgumentType)
5 errors, 0 warnings, 0 informations 

How can I make the type of an instance attribute depend on the value of a constructor argument?

Note that this works as expected with subprocess.Popen:

import subprocess
pp = subprocess.Popen("", text=True)
assert pp.stdout
pp.stdout.write("text")   # ok
pp.stdout.write(b"text")  # error
pp = subprocess.Popen("", text=False)
assert pp.stdout
pp.stdout.write("text")   # error
pp.stdout.write(b"text")  # ok

I tried reading the subprocess source code at https://github.com/python/cpython/blob/main/Lib/subprocess.py#L1015 , but I cannot find where its type annotations are stored.
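The standard library's annotations are not in subprocess.py itself; type checkers read them from typeshed's stub file `stdlib/subprocess.pyi`. There, `Popen` is generic over `AnyStr`, and its `__init__` overloads annotate `self` explicitly (as `Popen[str]` or `Popen[bytes]`) so the checker knows which specialization a call produces. The following is a much-simplified sketch of that pattern applied to `MyPopen`, assuming a checker that honors an annotated `self` in `__init__` overloads (pyright and mypy do, as far as I know); the `text` attribute is added only for illustration.

```python
from __future__ import annotations  # lets `MyPopen[str]` appear inside its own class body

from typing import IO, Generic, Literal, Optional, TypeVar, overload

# Constrained type variable, mirroring typeshed's Popen(Generic[AnyStr]).
T = TypeVar("T", str, bytes)


class MyPopen(Generic[T]):
    # Declared once, at class level, in terms of T.
    stdin: Optional[IO[T]]
    text: bool

    # Annotating `self` in each overload tells the checker which
    # specialization the constructor call produces.
    @overload
    def __init__(self: MyPopen[str], text: Literal[True]) -> None: ...
    @overload
    def __init__(self: MyPopen[bytes], text: Literal[False] = ...) -> None: ...
    def __init__(self, text: bool = False) -> None:
        self.text = text
        self.stdin = None


p_str = MyPopen(text=True)     # expected to be inferred as MyPopen[str]
p_bytes = MyPopen(text=False)  # expected to be inferred as MyPopen[bytes]
```

Unlike a subclass-based approach, the runtime object here is always a plain `MyPopen`; only the static type differs.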


Solution

  • I don't think this is a valid use of @overload. It is meant to specify more precise signatures for functions, not to conditionally set the type of attributes.

    However, you can create two empty subclasses, one for str and one for bytes, then let MyPopen.__new__() return an instance of either one and put the @overload on __new__() instead. The subclasses don't have to do anything other than serve as types; the logic is still defined in MyPopen.


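    # Needed at runtime: without it, the forward references to _BytesPopen and
    # _StrPopen in the return annotations below raise NameError.
    from __future__ import annotations

    from typing import IO, Literal, overload

    # Python 3.12+ syntax.
    # Or T = TypeVar('T', str, bytes); class MyPopen(Generic[T]): ...
    class MyPopen[T: (str, bytes)]: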
    
      stdin: IO[T] | None
    
      # Now you can use @overload!
    
      @overload
      def __new__(cls, text: Literal[False] = ...) -> _BytesPopen:
        ...
    
      @overload
      def __new__(cls, text: Literal[True]) -> _StrPopen:
        ...
    
      def __new__(cls, text: bool = False) -> _BytesPopen | _StrPopen:
        if text is True:
          subclass = _StrPopen
        else:
          subclass = _BytesPopen
    
        # Use `object.__new__()` or `type: ignore` to pass strict mode
        instance = super().__new__(subclass)
        instance.stdin = None
    
        return instance
    
      def write(self, output: T) -> None:
        ...
    
    class _StrPopen(MyPopen[str]):
      pass
    
    class _BytesPopen(MyPopen[bytes]):
      pass
    

    The results will then be:

    pp = MyPopen()
    assert pp.stdin
    pp.stdin.write("text")    # error: "Literal['text']" is incompatible with "bytes"
    pp.stdin.write(b"text")   # fine
    
    pp = MyPopen(text=False)
    assert pp.stdin
    pp.stdin.write("text")    # error: "Literal['text']" is incompatible with "bytes"
    pp.stdin.write(b"text")   # fine
    
    pp = MyPopen(text=True)
    assert pp.stdin
    pp.stdin.write("text")    # fine
    pp.stdin.write(b"text")   # error: "Literal[b"text"]" is incompatible with "str"