Search code examples
pythonpython-3.xpython-module

How to import modules containing classes with circular dependency?


Suppose I have the following base and child

class Base:

    def __new__(cls, *args):
        if cls is Base:
            if len(args) < 2:
                return Child1.__new__(Child1, *args)

            return Child2.__new__(Child2, *args)

        return super().__new__(cls)

    def __init__(self, arg):
        self.common_arg = arg


class Child1(Base):
    def __init__(self, arg0=None):
        super().__init__(arg0)



class Child2(Base):
    def __init__(self, arg0, arg1, *args):
        super().__init__(arg0 + arg1)

        self.args = list(args).copy()

There is clearly a circular dependency in between the classes, but, as long as all the classes are defined in the same module this does not cause any problems.

Now, how should I split them into three modules (in the same package)?

I did the split in three files:

package/
    __init__.py
    base.py
    ch1.py
    ch2.py

with the following contents:

# base.py ############################################################

from . import ch1, ch2

class Base:

    def __new__(cls, *args):
        if cls is Base:
            if len(args) < 2:
                return ch1.Child1.__new__(ch1.Child1, *args)

            return ch2.Child2.__new__(ch2.Child2, *args)

        return super().__new__(cls)

    def __init__(self, arg):
        self.common_arg = arg


# ch1.py ############################################################
from . import base

class Child1(base.Base):
    def __init__(self, arg0=None):
        super().__init__(arg0)

# ch2.py ############################################################
from . import base


class Child2(base.Base):
    def __init__(self, arg0, arg1, *args):
        super().__init__(arg0 + arg1)
        self.args = list(args).copy()   

as suggested here but it doesn't work.

import package.ch1

raises

AttributeError: module 'package.base' has no attribute 'Base'

Solution

  • Make your users call a Factory Function:

    def make_base(*args):
        if len(args) < 2:
            return Child1(*args)
    
        return Child2(*args)
    
    
    class Base:
        def __init__(self, arg):
            self.common_arg = arg
    
    
    class Child1(Base):
        pass        # Child1 automatically inherits Base.__init__()
    
    
    class Child2(Base):
        def __init__(self, arg0, arg1, *args):
            super().__init__(arg0 + arg1)
    
            self.args = list(args).copy()
    

    Now each part of the above code can be split into its own file.