python fastapi python-typing pydantic

N-dim numerical list type


How does one define a type for an N-dimensional numerical list (tensor) in Python 3.7? It would be used as one of the fields of a Pydantic BaseModel.

I'd like something like

from typing import List, Union

NumericalList = Union[
    int, float,
    List[int], List[float],
    List[List[int]], List[List[float]],
    ...
]
n1: NumericalList = [0]
n3: NumericalList = [ [ [0, 1, 2], [1, 1, 2] ],
                      [ [1, 2, 3], [0, 1, 3] ],
                    ]

I know that inside a class one can use a string literal to annotate children as being of the same type (or just add from __future__ import annotations). However, I'd like something iterable and sliceable, rather than an object whose children are accessed through attributes.

I thought a recursive definition might work, but it fails with AttributeError: __forward_arg__ after many levels of deepcopy inside the typing module.

NumericalList = Union[int, float, List["NumericalList"]]  # AttributeError: __forward_arg__

Note, however, that this fails only when the type is used with Pydantic, not in the Python IDE. Is this something specific to Pydantic, or is it the wrong way to go about it?


Solution

  • While I wish pydantic supported recursive types natively, you can use pydantic's custom root type models together with its strict types, so that float values are not coerced to int:

    from __future__ import annotations
    from typing import Union, List
    from pydantic import BaseModel, StrictInt, StrictFloat
    
    class NumericalList(BaseModel):
        __root__: Union[StrictInt, StrictFloat, List[NumericalList]]
    
    
    # Resolve the forward reference to NumericalList inside __root__ (pydantic v1 API).
    NumericalList.update_forward_refs()
    
    
    n1: NumericalList = NumericalList.parse_obj([0])
    """
    NumericalList(__root__=[NumericalList(__root__=0)])
    """
    
    n2: NumericalList = NumericalList.parse_obj(
                        [ [ [0, 1, 2], [1, 1, 2] ],
                          [ [1, 2, 3], [0, 1, 3] ],
                        ]
                       )
    """
    NumericalList(__root__=[
        NumericalList(__root__=[
            NumericalList(__root__=[NumericalList(__root__=0), NumericalList(__root__=1), NumericalList(__root__=2)]),
            NumericalList(__root__=[NumericalList(__root__=1), NumericalList(__root__=1), NumericalList(__root__=2)])
        ]),
        NumericalList(__root__=[
            NumericalList(__root__=[NumericalList(__root__=1), NumericalList(__root__=2), NumericalList(__root__=3)]),
            NumericalList(__root__=[NumericalList(__root__=0), NumericalList(__root__=1), NumericalList(__root__=3)])
        ])
    ])
    
    """
    
    
    n1.dict()
    """
    {'__root__': [0]}
    """
    
    n2.dict()
    """
    {'__root__': [
                      [ [0, 1, 2], [1, 1, 2] ],
                      [ [1, 2, 3], [0, 1, 3] ],
                 ]}
    """
    
    

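To make the accepted values concrete, here is a plain-Python sketch of the same recursive rule with no pydantic involved. is_numerical_list is a hypothetical helper (not part of pydantic), and it only approximates the model above: it assumes bool should be rejected and it does no coercion at all.

```python
from typing import Any


def is_numerical_list(value: Any) -> bool:
    """Return True if value is an int, a float, or a (nested) list of those."""
    # bool is a subclass of int, so exclude it explicitly (assumption of this sketch).
    if isinstance(value, bool):
        return False
    if isinstance(value, (int, float)):
        return True
    if isinstance(value, list):
        # Recurse into every element, mirroring List[NumericalList].
        return all(is_numerical_list(item) for item in value)
    return False


print(is_numerical_list([[0, 1.5], [2, 3]]))  # True
print(is_numerical_list([0, "1"]))            # False
```

The recursion has the same shape as the Union in the model: two scalar cases and one list case that refers back to the rule itself.
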
    You can extend the class with list-like methods, but this is optional:

    class NumericalList(BaseModel):
        __root__: Union[StrictInt, StrictFloat, List[NumericalList]]
    
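        # Delegate iteration and indexing to the wrapped value.
        # These only work when __root__ holds a list, not a scalar.
        def __iter__(self):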
            return iter(self.__root__)
    
        def __getitem__(self, index):
            return self.__root__[index]
    
        def __setitem__(self, index, value):
            self.__root__[index] = value
    
    
    NumericalList.update_forward_refs()
    
    n3: NumericalList = NumericalList.parse_obj([0, 5, [1]])
    for i in n3:
        print(i)
    """
    __root__=0
    __root__=5
    __root__=[NumericalList(__root__=1)]
    """
    

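Note that the recursive type only constrains nesting, not shape, so a ragged value such as [0, 5, [1]] is accepted. If a rectangular N-dim tensor is required, a check along these lines could be run on the plain nested list. tensor_shape is a hypothetical helper name, and treating an empty list as shape (0,) is an assumption of this sketch.

```python
from typing import Any, Tuple


def tensor_shape(value: Any) -> Tuple[int, ...]:
    """Return the shape of a rectangular nested list, or raise ValueError."""
    if not isinstance(value, list):
        return ()  # scalars have an empty shape
    if not value:
        return (0,)  # assumption: an empty list is a 1-D tensor of length 0
    # Every element must have the same shape for the list to be rectangular.
    shapes = {tensor_shape(item) for item in value}
    if len(shapes) != 1:
        raise ValueError(f"ragged nested list: element shapes {sorted(shapes)}")
    return (len(value),) + shapes.pop()


print(tensor_shape([[[0, 1, 2], [1, 1, 2]],
                    [[1, 2, 3], [0, 1, 3]]]))  # (2, 2, 3)

try:
    tensor_shape([0, 5, [1]])
except ValueError as exc:
    print("rejected:", exc)
```
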
    If you want the native Python types back, call .dict():

    n3.dict()
    """
    {'__root__': [0, 5, [1]]}
    """