Tags: python, graph, orm, sqlalchemy, many-to-many

How can I write a SQLAlchemy query that will return all descendants of a node in a graph?


I am working on an application where database objects often have multiple parents and multiple children, and I would like to create a SQLAlchemy query that returns all descendants of an object.

Realizing that I am basically trying to store a graph in a SQL database, I found that setting up a self-referential many-to-many schema got me most of the way there, but I am having trouble writing the query that returns all descendants of a node. I tried to follow SQLAlchemy's recursive CTE example, which looks like the right approach, but I have not been able to get it to work. I think my situation differs from the example because, in my case, accessing Node.child (or Node.parent) returns an instrumented list rather than a single ORM object.

In any case, the code below sets up a simple directed, acyclic, disconnected graph that looks like this (edges point from the higher row to the lower one):

a   b    c
 \ / \   |
  d   e  f
  |\ /
  g h     
  |
  i

What I'm looking for is help writing a query that returns all descendants of a node. For example:

  • get_descendants(d) should return g, h, i

  • get_descendants(b) should return d, e, g, h, i
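Before touching SQL, the expected results above can be pinned down with a plain-Python breadth-first traversal over an adjacency dict. This is only a reference sketch: the EDGES dict and the get_descendants function are hypothetical stand-ins for the database, not SQLAlchemy APIs. The `seen` set is what keeps h from being reported twice even though it is reachable through both d and e.

```python
from collections import deque

# Hypothetical adjacency dict mirroring the question's graph: parent -> children.
EDGES = {
    "a": ["d"],
    "b": ["d", "e"],
    "c": ["f"],
    "d": ["g", "h"],
    "e": ["h"],
    "g": ["i"],
}


def get_descendants(node, edges=EDGES):
    """Return every node reachable from `node`, excluding `node` itself."""
    seen = set()
    queue = deque(edges.get(node, []))
    while queue:
        current = queue.popleft()
        if current in seen:
            continue  # already reached via another parent (e.g. h via d and e)
        seen.add(current)
        queue.extend(edges.get(current, []))
    return seen


print(sorted(get_descendants("d")))  # ['g', 'h', 'i']
print(sorted(get_descendants("b")))  # ['d', 'e', 'g', 'h', 'i']
```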

Example code:

from sqlalchemy.orm import aliased

from sqlalchemy import Column, ForeignKey, Integer, Table, Text
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.orm import sessionmaker

engine = create_engine('sqlite:///:memory:', echo=True)
Session = sessionmaker(bind=engine)

session = Session()

Base = declarative_base()

association_table = Table('association_table', Base.metadata,
                           Column('parent_id', Integer, ForeignKey('node.id'), primary_key=True),
                           Column('child_id', Integer, ForeignKey('node.id'), primary_key=True))


class Node(Base):
    __tablename__ = 'node'
    id = Column(Integer, primary_key=True)
    property_1 = Column(Text)
    property_2 = Column(Integer)

    # http://docs.sqlalchemy.org/en/latest/orm/join_conditions.html#self-referential-many-to-many-relationship
    child = relationship('Node',
                            secondary=association_table,
                            primaryjoin=id==association_table.c.parent_id,
                            secondaryjoin=id==association_table.c.child_id,
                            backref='parent'
                            )

Base.metadata.create_all(engine)

a = Node(property_1='a', property_2=1)
b = Node(property_1='b', property_2=2)
c = Node(property_1='c', property_2=3)
d = Node(property_1='d', property_2=4)
e = Node(property_1='e', property_2=5)
f = Node(property_1='f', property_2=6)
g = Node(property_1='g', property_2=7)
h = Node(property_1='h', property_2=8)
i = Node(property_1='i', property_2=9)



session.add_all([a, b, c, d, e, f, g, h, i])
a.child.append(d)
b.child.append(d)
d.child.append(g)
d.child.append(h)
g.child.append(i)
b.child.append(e)
e.child.append(h)
c.child.append(f)

session.commit()
session.close()

Solution


    The following, surprisingly simple, self-referential many-to-many recursive CTE query returns b together with all of its descendants:

    nodealias = aliased(Node)
    
    descendants = session.query(Node)\
        .filter(Node.id == b.id) \
        .cte(name="descendants", recursive=True)
    
    descendants = descendants.union(
        session.query(nodealias)\
        .join(descendants, nodealias.parent)
    )
    

    Testing with

    for item in session.query(descendants):
        print(item.property_1, item.property_2)
    

    Yields:

    b 2
    d 4
    e 5
    g 7
    h 8
    i 9
    

    This is the correct list: b itself plus all of its descendants.
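For reference, the ORM query above builds a WITH RECURSIVE common table expression: an anchor member that selects the starting node and a recursive member that joins children through association_table. The sketch below is a hand-written, roughly equivalent query run with the standard-library sqlite3 module instead of SQLAlchemy, so the two members are explicit. Table and column names mirror the question's schema; the SQL that SQLAlchemy actually emits (visible with echo=True) will differ in aliases and labels, and descendants_of is a hypothetical helper. UNION (rather than UNION ALL) is what drops the second copy of h, which is reachable through both d and e.

```python
import sqlite3

# In-memory copy of the question's schema (node + association_table).
conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE node (id INTEGER PRIMARY KEY, property_1 TEXT, property_2 INTEGER);
    CREATE TABLE association_table (
        parent_id INTEGER REFERENCES node(id),
        child_id  INTEGER REFERENCES node(id),
        PRIMARY KEY (parent_id, child_id)
    );
""")

# Same nodes as the question: a..i with property_2 = 1..9 (ids assigned 1..9 here).
names = "abcdefghi"
id_of = {name: n for n, name in enumerate(names, start=1)}
conn.executemany("INSERT INTO node VALUES (?, ?, ?)",
                 [(n, name, n) for name, n in id_of.items()])

# Same parent -> child edges as the question.
edges = [("a", "d"), ("b", "d"), ("d", "g"), ("d", "h"),
         ("g", "i"), ("b", "e"), ("e", "h"), ("c", "f")]
conn.executemany("INSERT INTO association_table VALUES (?, ?)",
                 [(id_of[p], id_of[c]) for p, c in edges])

DESCENDANTS_SQL = """
WITH RECURSIVE descendants(id, property_1, property_2) AS (
    -- anchor member: the starting node itself
    SELECT id, property_1, property_2 FROM node WHERE id = ?
    UNION
    -- recursive member: every child of a row already in descendants
    SELECT child.id, child.property_1, child.property_2
    FROM node AS child
    JOIN association_table AS assoc ON assoc.child_id = child.id
    JOIN descendants ON assoc.parent_id = descendants.id
)
SELECT property_1, property_2 FROM descendants ORDER BY property_2
"""


def descendants_of(name):
    """Return (property_1, property_2) for `name` and all of its descendants."""
    return conn.execute(DESCENDANTS_SQL, (id_of[name],)).fetchall()


for row in descendants_of("b"):
    print(*row)
```

Run as-is, this prints the same six rows as the ORM version, b 2 through i 9.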

    Full working example code

    This example adds a convenience method to the Node class that returns all descendants of a node (excluding the node itself) together with the path from the node to each descendant. It also adds an extra e -> i edge to the graph:

    from sqlalchemy.orm import aliased
    from sqlalchemy import Column, ForeignKey, Integer, Table, Text
    from sqlalchemy import create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import relationship
    from sqlalchemy.orm import sessionmaker
    
    engine = create_engine('sqlite://', echo=True)
    Session = sessionmaker(bind=engine)
    
    session = Session()
    
    Base = declarative_base()
    
    association_table = Table('association_table', Base.metadata,
                               Column('parent_id', Integer, ForeignKey('node.id'), primary_key=True),
                               Column('child_id', Integer, ForeignKey('node.id'), primary_key=True))
    
    
    class Node(Base):
        __tablename__ = 'node'
        id = Column(Integer, primary_key=True)
        property_1 = Column(Text)
        property_2 = Column(Integer)
    
        # http://docs.sqlalchemy.org/en/latest/orm/join_conditions.html#self-referential-many-to-many-relationship
        child = relationship('Node',
                                secondary=association_table,
                                primaryjoin=id==association_table.c.parent_id,
                                secondaryjoin=id==association_table.c.child_id,
                                backref='parent'
                                )
    
        def descendant_nodes(self):
            nodealias = aliased(Node)
            descendants = session.query(Node.id, Node.property_1, (self.property_1 + '/' + Node.property_1).label('path')).filter(Node.parent.contains(self))\
                .cte(recursive=True)
            descendants = descendants.union(
                session.query(nodealias.id, nodealias.property_1, (descendants.c.path + '/' + nodealias.property_1).label('path')).join(descendants, nodealias.parent)
            )
            return session.query(descendants.c.property_1, descendants.c.path).all()
    
    
    Base.metadata.create_all(engine)
    
    a = Node(property_1='a', property_2=1)
    b = Node(property_1='b', property_2=2)
    c = Node(property_1='c', property_2=3)
    d = Node(property_1='d', property_2=4)
    e = Node(property_1='e', property_2=5)
    f = Node(property_1='f', property_2=6)
    g = Node(property_1='g', property_2=7)
    h = Node(property_1='h', property_2=8)
    i = Node(property_1='i', property_2=9)
    
    
    
    session.add_all([a, b, c, d, e, f, g, h, i])
    a.child.append(d)
    b.child.append(d)
    d.child.append(g)
    d.child.append(h)
    g.child.append(i)
    b.child.append(e)
    e.child.append(h)
    c.child.append(f)
    e.child.append(i)
    
    session.commit()
    
    
    for item in b.descendant_nodes():
        print(item)
    
    session.close()
    
    
    """
    Graph should be set up like this:
    
    a   b    c
     \ / \   |
      d   e  f
      |\ /|
      g h |    
      +---+
      i
    
    """
    

    Output:

    ('d', 'b/d')
    ('e', 'b/e')
    ('g', 'b/d/g')
    ('h', 'b/d/h')
    ('h', 'b/e/h')
    ('i', 'b/e/i')
    ('i', 'b/d/g/i')
    
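Because each row now carries its own path, the union no longer collapses h and i: every distinct path is a distinct row, which is why they appear twice in the output above. If you need each descendant only once, one option is to group the path rows afterwards. The sketch below shows that with the standard-library sqlite3 module and a simplified, hypothetical edge(parent, child) table keyed by name rather than the question's id-based schema; paths_from and path_counts are illustrative helpers, not part of the answer. It uses UNION ALL, so it enumerates every path; that is fine for the acyclic graph here, but the recursion would not terminate on a graph with cycles.

```python
import sqlite3

# Simplified, hypothetical edge table keyed by node name rather than by id.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE edge (parent TEXT, child TEXT, PRIMARY KEY (parent, child))")
conn.executemany(
    "INSERT INTO edge VALUES (?, ?)",
    [("a", "d"), ("b", "d"), ("d", "g"), ("d", "h"), ("g", "i"),
     ("b", "e"), ("e", "h"), ("c", "f"), ("e", "i")],  # includes the extra e -> i edge
)

# UNION ALL keeps one row per distinct path; on a cyclic graph this would not terminate.
PATHS_CTE = """
WITH RECURSIVE descendants(name, path) AS (
    -- anchor member: direct children of the start node
    SELECT child, parent || '/' || child FROM edge WHERE parent = :start
    UNION ALL
    -- recursive member: extend each path by one more edge
    SELECT edge.child, descendants.path || '/' || edge.child
    FROM edge JOIN descendants ON edge.parent = descendants.name
)
"""


def paths_from(start):
    """Every (descendant, path) pair, one row per path."""
    sql = PATHS_CTE + "SELECT name, path FROM descendants ORDER BY name, path"
    return conn.execute(sql, {"start": start}).fetchall()


def path_counts(start):
    """Each descendant once, with the number of distinct paths that reach it."""
    sql = PATHS_CTE + "SELECT name, COUNT(*) FROM descendants GROUP BY name ORDER BY name"
    return conn.execute(sql, {"start": start}).fetchall()


print(paths_from("b"))
print(path_counts("b"))  # [('d', 1), ('e', 1), ('g', 1), ('h', 2), ('i', 2)]
```

paths_from("b") returns the same seven name/path pairs as the ORM output above, just sorted by name and path.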
