
Python/Neo4j Query Optimization


I have run out of ideas trying to find the cause of such low write speeds. My background is in relational databases, so I might be doing something wrong. Adding 10 nodes and 45 relationships to an empty database currently takes 1.4 seconds. This is unacceptable; in my opinion it should take on the order of milliseconds, if that.

Requirement

Create a method to add snapshots to the Neo4j database. One snapshot consists of 10 nodes (all with different labels and properties). I need to connect every pair of nodes in a snapshot, unidirectionally and without self-loops, which comes to 10 nodes and 45 relationships per snapshot. Relationships are created with the property strength = 1. Every time I add a relationship that already exists (i.e., a match on nodeA(oHash) -> nodeB(oHash)), I increment its strength instead of creating a duplicate.

Measurements

I compensated for all overhead from the API, Python itself, etc. Currently over 99.9% of the execution time is spent querying Neo4j. I observed that creating the nodes is much slower than creating the relationships (roughly 90% of the total time).

Writing one snapshot (10 nodes, 45 relationships) into an empty database takes my Python query 1.4 seconds, averaged over 100 runs.
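For reproducibility, the averaging was done with a small harness along these lines (a sketch; `average_seconds` and the lambda workload are illustrative names, not my exact benchmark code):

```python
import time

def average_seconds(fn, runs=100):
    """Average wall-clock time of fn() over `runs` calls."""
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - start) / runs

# e.g. average_seconds(lambda: repo.add_snapshot(data)) against an emptied DB
```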

Indexes

In the code below you will find a create-constraints method that I never call. That is because I already created the constraints on all node labels, and I removed the call to avoid the overhead of checking for existing indexes. Every node has an "oHash" property: an MD5 hash of the JSON of all its properties, excluding the internal Neo4j <ID>. This uniquely identifies my nodes, so I created a UNIQUE constraint on "oHash". As far as I understand, creating a UNIQUE constraint in Neo4j also creates an index on that property.

Best Practices

I used all recommended best practices I could find online. These include:

  1. Creating a single driver instance and reusing it
  2. Creating a single driver session and reusing it
  3. Using explicit transactions
  4. Using query parameters
  5. Creating a batch and executing as a single transaction
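Practices 3–5 combine into a pattern like the following (a hedged sketch: `write_batch`, the `Player` label, and the `$rows` shape are illustrative placeholders, not my actual schema):

```python
def write_batch(tx, rows):
    # One parameterized UNWIND query sends the whole batch in a single
    # round-trip inside the explicit (managed) transaction.
    tx.run(
        "UNWIND $rows AS row "
        "MERGE (n:Player {oHash: row.oHash}) "
        "SET n = row",
        rows=rows,
    )

# Usage with a managed transaction:
# session.execute_write(write_batch, rows)
```

Note that label names cannot be supplied as query parameters in Cypher, which is why the real code below interpolates `node['label']` into the query string instead.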

Implementation

Here is my current implementation:

import json
import hashlib
import uuid
from neo4j import GraphDatabase

class SnapshotRepository:
    """A repository to handle snapshots in a Neo4j database."""

    def __init__(self):
        """Initialize a connection to the Neo4j database."""
        with open("config.json", "r") as file:
            config = json.load(file)
        self._driver = GraphDatabase.driver(
            config["uri"], auth=(config["username"], config["password"])
        )
        self._session = self._driver.session()

    def delete_all(self):
        """Delete all nodes and relationships from the graph."""
        self._session.run("MATCH (n) DETACH DELETE n")

    def add_snapshot(self, data):
        """
        Add a snapshot to the Neo4j database.
        
        Args:
            data (dict): The snapshot data to be added.
        """
        snapshot_id = str(uuid.uuid4())  # Generate a unique snapshot ID
        self._session.execute_write(self._add_complete_graph, data, snapshot_id)

    def _create_constraints(self, tx, labels):
        """
        Create uniqueness constraints for the specified labels.
        
        Args:
            tx (neo4j.Transaction): The transaction to be executed.
            labels (list): List of labels for which to create uniqueness constraints.
        """
        for label in labels:
            tx.run(f"CREATE CONSTRAINT IF NOT EXISTS FOR (n:{label}) REQUIRE n.oHash IS UNIQUE")

    @staticmethod
    def _calculate_oHash(node):
        """
        Calculate the oHash for a node based on its properties.

        Args:
            node (dict): The node properties.

        Returns:
            str: The calculated oHash.
        """
        properties = {k: v for k, v in node.items() if k not in ['id', 'snapshotId', 'oHash']}
        properties_json = json.dumps(properties, sort_keys=True)
        return hashlib.md5(properties_json.encode('utf-8')).hexdigest()

    def _create_or_update_nodes(self, tx, nodes, snapshot_id):
        """
        Create or update nodes in the graph.

        Args:
            tx (neo4j.Transaction): The transaction to be executed.
            nodes (list): The nodes to be created or updated.
            snapshot_id (str): The ID of the snapshot.
        """
        for node in nodes:
            node['oHash'] = self._calculate_oHash(node)
            node['snapshotId'] = snapshot_id
            tx.run("""
                MERGE (n:{0} {{oHash: $oHash}})
                ON CREATE SET n = $props
                ON MATCH SET n = $props
            """.format(node['label']), oHash=node['oHash'], props=node)

    def _create_relationships(self, tx, prev, curr):
        """
        Create relationships between nodes in the graph.

        Args:
            tx (neo4j.Transaction): The transaction to be executed.
            prev (dict): The properties of the previous node.
            curr (dict): The properties of the current node.
        """
        if prev and curr:
            oHashA = self._calculate_oHash(prev)
            oHashB = self._calculate_oHash(curr)
            tx.run("""
                MATCH (a:{0} {{oHash: $oHashA}}), (b:{1} {{oHash: $oHashB}})
                MERGE (a)-[r:HAS_NEXT]->(b)
                ON CREATE SET r.strength = 1
                ON MATCH SET r.strength = r.strength + 1
            """.format(prev['label'], curr['label']), oHashA=oHashA, oHashB=oHashB)

    def _add_complete_graph(self, tx, data, snapshot_id):
        """
        Add a complete graph to the Neo4j database for a given snapshot.

        Args:
            tx (neo4j.Transaction): The transaction to be executed.
            data (dict): The snapshot data.
            snapshot_id (str): The ID of the snapshot.
        """
        nodes = data['nodes']
        self._create_or_update_nodes(tx, nodes, snapshot_id)
        tx.run("""
            MATCH (a {snapshotId: $snapshotId}), (b {snapshotId: $snapshotId})
            WHERE a.oHash < b.oHash
            MERGE (a)-[r:HAS]->(b)
            ON CREATE SET r.strength = 1, r.snapshotId = $snapshotId
            ON MATCH SET r.strength = r.strength + 1
        """, snapshotId=snapshot_id)
        self._create_relationships(tx, data.get('previousMatchSnapshotNode', None), data.get('currentMatchSnapshotNode', None))
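One property of `_calculate_oHash` worth noting: because of `sort_keys=True` and the excluded keys, the hash is stable across key order and across snapshots, which is what makes the MERGE-on-oHash deduplication work. A standalone reproduction of the method:

```python
import json
import hashlib

def calculate_oHash(node):
    # Same logic as SnapshotRepository._calculate_oHash, reproduced here.
    properties = {k: v for k, v in node.items()
                  if k not in ('id', 'snapshotId', 'oHash')}
    return hashlib.md5(
        json.dumps(properties, sort_keys=True).encode('utf-8')
    ).hexdigest()

a = calculate_oHash({"label": "Player", "name": "Ann", "snapshotId": "s1"})
b = calculate_oHash({"name": "Ann", "label": "Player", "snapshotId": "s2"})
assert a == b  # key order and snapshotId do not affect the hash
```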

All input and suggestions are welcome.


Solution

  • I have found the solution. For everyone who might stumble across this thread, here are the two things that were wrong with the code.

    Firstly, I never closed the session. Since I was dealing with GCP Functions, I completely forgot about this. I added it to the class destructor:

    def __del__(self):
        """Close the connection to the Neo4j database."""
        self._session.close()
        self._driver.close()
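    Since Python does not guarantee that `__del__` ever runs (e.g. at interpreter shutdown), a more robust variant is an explicit `close()` plus context-manager support. A sketch (the `ClosableRepository` name and constructor shape are illustrative):

```python
class ClosableRepository:
    """Sketch: explicit close() plus context-manager support, so cleanup
    does not rely on __del__ being called."""

    def __init__(self, driver, session):
        self._driver = driver
        self._session = session

    def close(self):
        self._session.close()
        self._driver.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

# with ClosableRepository(driver, driver.session()) as repo:
#     ...  # session and driver are closed on exit, even on error
```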
    

    Secondly, I read in the Neo4j performance recommendations (https://neo4j.com/docs/python-manual/current/performance/) that not specifying the database name when opening sessions can cause significant overhead when many queries are executed, which was exactly my situation.

    self._db_name = "neo4j"
    self._session = self._driver.session(database=self._db_name)
    

    What I found when inspecting query plans with EXPLAIN is that the database lookup accounted for a major part of most queries. I usually set the DB name in an SQL connector and never think twice about it.