Search code examples
Tags: python, repository-pattern

Link pairs of type bounds together in Python


I am trying to create an abstract repository class so that I can drastically reduce the amount of boilerplate code I write. To do so, I need to link the bound of Entity to a specific Domain (as when using Pick in TypeScript), so that mypy reports errors such as "Missing positional arguments". Is this possible in Python?

"""
Generic Repository
"""
from src.Managers.session import SessionManager
from typing import Generic, TypeVar, Union
from sqlalchemy.orm import Query
from src.User.models.TermsOfUse import TermsOfUseModel, TermsOfUse
from src.User.models.User import UserModel, User

# Upper bounds for the repository's type parameters: every ORM model and
# every domain class the generic repository may be parameterised with.
# NOTE: the two bounds are independent of each other, so nothing in these
# TypeVars ties a given entity type to its matching domain type.
entity_bound = Union[
    UserModel,
    TermsOfUseModel,
]
domain_bound = Union[
    User,
    TermsOfUse,
]

# Exception class the repository raises when a lookup finds no record.
SpecificException = TypeVar("SpecificException", bound=Exception)
Entity = TypeVar("Entity", bound=entity_bound)
Domain = TypeVar("Domain", bound=domain_bound)


class RepoPattern(Generic[Entity, Domain, SpecificException]):
    """
    Repository Pattern.

    Generic repository that persists domain objects through an ORM entity
    class and a session manager, converting between the two with
    ``to_dict`` / ``from_dict``.

    Type parameters:
    ----
        Entity: ORM model class (bounded by ``entity_bound``).
        Domain: domain class (bounded by ``domain_bound``).
        SpecificException: exception raised when a lookup finds no record.

    NOTE: the type parameters are bounded independently, so mypy does not
    check that ``entity`` and ``domain`` form a matching pair.
    """

    def __init__(
        self,
        entity: type[Entity],
        domain: type[Domain],
        exception: type[SpecificException],
        session_context: SessionManager,
    ) -> None:
        """
        Parameters:
        ----
            :param entity: ORM model class used to build queries and to
                instantiate records from domain objects.
            :param domain: domain class that records are converted into.
            :param exception: exception class, instantiated with no
                arguments and raised when a lookup finds no record.
            :param session_context: session manager that receives the
                add / modify / remove calls.
        """
        self._entity = entity
        self._domain = domain
        self._exception = exception
        self._session_context = session_context

    @property
    def session_context(self) -> SessionManager:
        """Session manager passed to the constructor."""
        return self._session_context

    def add(self, domain: Domain) -> None:
        """
        Add a new domain object.

        The domain object is converted to an entity and handed to the
        session manager's ``add``.

        Parameters:
        ----
            :param domain: domain object.

        Returns:
        ----
            None
        """
        entity = self._get_entity_from_domain(domain)
        self._session_context.add(entity)

    def modify(self, domain: Domain) -> None:
        """
        Modify an existing domain object.

        Loads the stored record with the same id, resets it from
        ``domain.to_dict()`` and hands it to the session manager's
        ``modify``.

        Parameters:
        ----
            :param domain: domain object; its ``id`` must be set (truthy).

        Returns:
        ----
            None

        Raises:
        ----
            ValueError: if ``domain.id`` is not set.
            SpecificException: if no record with that id exists.
        """
        domain_id = self._require_id(domain)
        entity = self._find_first_record(self._find_by_id_query(domain_id))
        entity.reset(domain.to_dict())
        self._session_context.modify(entity)

    def remove(self, domain: Domain) -> None:
        """
        Remove a domain object.

        Loads the stored record with the same id and hands it to the
        session manager's ``remove``.

        Parameters:
        ----
            :param domain: domain object; its ``id`` must be set (truthy).

        Returns:
        ----
            None

        Raises:
        ----
            ValueError: if ``domain.id`` is not set.
            SpecificException: if no record with that id exists.
        """
        domain_id = self._require_id(domain)
        entity = self._find_first_record(self._find_by_id_query(domain_id))
        self._session_context.remove(entity)

    def find_by_id(self, id: int) -> Domain:
        """
        Find a domain object by id.

        Parameters:
        ----
            :param id: id of the entity.

        Returns:
        ----
            Domain

        Raises:
        ----
            SpecificException: if no record with that id exists.
        """
        return self._find_first_domain(self._find_by_id_query(id))

    def _require_id(self, domain: Domain) -> int:
        """
        Return ``domain.id``, raising ``ValueError`` when it is not set.

        An explicit check is used instead of ``assert`` because asserts are
        stripped under ``python -O``, which would let a missing id reach the
        query. Falsy ids are rejected, as the original ``assert`` did.
        """
        if not domain.id:
            raise ValueError("domain object has no id; it must be persisted first")
        return domain.id

    def _find_by_id_query(self, id: int) -> Query[Entity]:
        """
        Build (but do not execute) a query selecting the entity by id.

        Parameters:
        ----
            :param id: id of the entity.

        Returns:
        ----
            Query[Entity]
        """
        return self._entity.query.filter_by(id=id)

    def _find_first_domain(self, query: Query[Entity]) -> Domain:
        """
        Return the first record of the query, converted to a domain object.

        Parameters:
        ----
            :param query: query.

        Returns:
        ----
            Domain

        Raises:
        ----
            SpecificException: if the query yields no record.
        """
        record = self._find_first_record(query)
        return self._get_domain_from_entity(record)

    def _find_all(self, query: Query[Entity]) -> list[Domain]:
        """
        Return every record of the query, converted to domain objects.

        Parameters:
        ----
            :param query: query.

        Returns:
        ----
            list[Domain]
        """
        return [
            self._get_domain_from_entity(record)
            for record in self._find_all_records(query)
        ]

    def _find_first_record(self, query: Query[Entity]) -> Entity:
        """
        Return the first record of the query.

        Parameters:
        ----
            :param query: query.

        Returns:
        ----
            Entity

        Raises:
        ----
            SpecificException: if the query yields no record.
        """
        record = query.first()
        # Compare with None explicitly: ``first()`` signals "no match" with
        # None, and an ORM instance should not be judged by its truthiness.
        if record is None:
            raise self._exception()
        return record

    def _find_all_records(self, query: Query[Entity]) -> list[Entity]:
        """
        Return all records of the query.

        Parameters:
        ----
            :param query: query.

        Returns:
        ----
            list[Entity]
        """
        return query.all()

    def _get_domain_from_entity(self, entity: Entity) -> Domain:
        """
        Convert an entity into a domain object (``to_dict`` -> ``from_dict``).

        Parameters:
        ----
            :param entity: entity used.

        Returns:
        ----
            Domain
        """
        return self._domain.from_dict(entity.to_dict())

    def _get_entity_from_domain(self, domain: Domain) -> Entity:
        """
        Convert a domain object into a new entity instance.

        The domain object's ``to_dict()`` is passed as keyword arguments to
        the entity class.

        Parameters:
        ----
            :param domain: domain used.

        Returns:
        ----
            Entity
        """
        return self._entity(**domain.to_dict())

Edit: I want mypy to report errors such as `Argument 1 to "reset" of "UserModel" has incompatible type "TermsOfUseDomainDict"; expected "UserDomainDict"  [arg-type]` (mypy error) when instantiating the repository like this:

repo = RepoPattern(
    UserModel,
    TermsOfUse,  # mismatched on purpose: UserModel pairs with User
    UserNotFoundException,
    SessionManager(),
)

However, I cannot express this by bounding my generic types with a simple Union such as:

entity_bound = Union[
    UserModel,
    TermsOfUseModel,
]
domain_bound = Union[
    User,
    TermsOfUse,
]

What I need is to pick the domain type bound from the entity type. For instance, when UserModel is selected as the entity type, the repository should automatically use User as the domain type.


Solution

  • If you want to inherit from RepoPattern for each entity-domain pair, you can specialize the TypeVars in the derived class. Also, do not hard-code entity_bound and domain_bound in the module of the abstract base class; consider working with abstract Entity and Domain base classes instead. I also didn't understand why you made SpecificException a type variable: apart from the exception argument of __init__, it doesn't appear in the class interface, so a plain type[Exception] would do. With this in mind, the code could look something like this:

    from typing import TypeVar, Generic, Type
    from entity import Entity
    from domain import Domain
    
    # Each type parameter accepts any subclass of the abstract base class.
    TEntity = TypeVar("TEntity", bound=Entity)
    TDomain = TypeVar("TDomain", bound=Domain)
    
    class RepositoryBase(Generic[TEntity, TDomain]):
        """Base repository, generic over one entity type and one domain type."""

        def __init__(
            self,
            entity_type: Type[TEntity],
            domain_type: Type[TDomain],
        ) -> None:
            ...  # implementation omitted in this example
    
    class UserRepository(RepositoryBase[UserModel, User]):
        """Repository with its type parameters fixed to (UserModel, User)."""
    

    But if you prefer to keep your solution, you can overload the __init__ function to only accept specific combinations:

    from __future__ import annotations
    from typing import TypeVar, Generic, Type, overload
    from entity import Entity
    from domain import Domain
    
    TEntity = TypeVar("TEntity", bound=Entity)
    TDomain = TypeVar("TDomain", bound=Domain)
    
    class Repository(Generic[TEntity, TDomain]):
        """Repository whose allowed (entity, domain) pairs are fixed by overloads."""

        # One overload per allowed pair; annotating ``self`` tells mypy which
        # concrete types TEntity and TDomain are bound to for that call.
        @overload
        def __init__(
            self: Repository[UserModel, User],
            entity_type: Type[UserModel],
            domain_type: Type[User],
        ) -> None:
            ...
    
        @overload
        def __init__(
            self: Repository[TermsOfUseModel, TermsOfUse],
            entity_type: Type[TermsOfUseModel],
            domain_type: Type[TermsOfUse],
        ) -> None:
            ...
    
        # Parameters intentionally left unannotated; callers are checked
        # against the overloads above, not against this implementation.
        def __init__(self, entity_type, domain_type) -> None:
            ...
    
    
    Repository(UserModel, User)
    Repository(TermsOfUseModel, TermsOfUse)
    Repository(
        UserModel, TermsOfUse
    )  # error: Argument 2 to "Repository" has incompatible type "type[TermsOfUse]"; expected "type[User]"  [arg-type]
    

    You should avoid annotating the parameters of the actual __init__ implementation, because the union of the two overloads does not cover all combinations. Instead, annotate the self parameter in the overloads to guide mypy in inferring the TypeVars.

    Some side notes:

    • Avoid importing modules from the src directory, as your imports will not work when the code is installed.
    • Apart from the entity argument of __init__, no public method in your code uses the Entity type variable. Consider removing it from the type parameters and using a runtime mapping or function to create or work with entities from the domain object or domain type.
    • Mapping between similar or related objects in different layers is not the repository's responsibility. Consider using auto-mappers like py-automapper or object-mapper. Alternatively, you can implement your specific mappers or place the mapping logic in the class on the outer layer.
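    • mypy itself does not support type-level mappings (e.g. "UserModel maps to User"); such pairings have to be spelled out explicitly, for example through subclasses or overloads as shown above.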
    • Alternatively, you might find the Active Record pattern and its implementations in Python more suitable for your needs.