
Batch insert fails for table with a CQL map and nested UDT


I'm trying to write some data to Cassandra. A connection is established, and this is my database schema:

CREATE TYPE IF NOT EXISTS UDTVehicle (
    num_vehicles INT,
    speed LIST<FLOAT>
);

CREATE TABLE IF NOT EXISTS road_traffic (
    road_id INT,
    timestamp TIMESTAMP,
    radar_id INT,
    vehicles MAP<INT, FROZEN<UDTVehicle>>,
    PRIMARY KEY (road_id, timestamp) // partition key and clustering key
);

And this is my Python code:

def write_to_cassandra(self,session, record):
        insert_query = """
        INSERT INTO road_traffic (road_id, timestamp, radar_id, vehicles)
        VALUES (?, ?, ?, ?);
        """
        prepared_insert = session.prepare(insert_query)
        batch = BatchStatement()
        batch_size = 0
        batches= 0
        print(record)
        if not all(record.get(field) is not None for field in list(record.keys())) :
            logger.warning("Record is missing a required fields")
            return
        else :
            vehicles_map = {record.get('road_id'): record.get('Vehicles')}
            print(record.get('road_id'))
            batch.add(prepared_insert, (record.get('road_id'), record.get('timestamp'), record.get('radar_id'), vehicles_map))
            batch_size += 1

As far as I understand, vehicles_map must map a key (the partition key) to a value (my UDTVehicle object). This is what I get when I print the record:

{'timestamp': datetime.datetime(2022, 1, 1, 8, 0, 30), 'num_Vehicles': 7, 'road_id': 9, 'radar_id': 11, 'Vehicles': {'num_Vehicles': 7, 'speed': [72.78, 62.67, 85.15, 75.51, 83.95, 76.39, 57.92]}}

The error I get is as follows:

Traceback (most recent call last):
  File "cassandra_kafka/ingest_cassandra.py", line 135, in <module>
    Data_ingest.consume_store()
  File "cassandra_kafka/ingest_cassandra.py", line 113, in consume_store
    self.write_to_cassandra(session, value)
  File "cassandra_kafka/ingest_cassandra.py", line 66, in write_to_cassandra
    batch.add(prepared_insert, (record.get('road_id'), record.get('timestamp'), record.get('radar_id'), vehicles_map))
  File "cassandra/query.py", line 827, in cassandra.query.BatchStatement.add
  File "cassandra/query.py", line 506, in cassandra.query.PreparedStatement.bind
  File "cassandra/query.py", line 636, in cassandra.query.BoundStatement.bind
  File "cassandra/cqltypes.py", line 799, in cassandra.cqltypes._ParameterizedType.serialize
  File "cassandra/cqltypes.py", line 909, in cassandra.cqltypes.MapType.serialize_safe
  File "cassandra/cqltypes.py", line 324, in cassandra.cqltypes._CassandraType.to_binary
  File "cassandra/cqltypes.py", line 799, in cassandra.cqltypes._ParameterizedType.serialize
  File "cassandra/cqltypes.py", line 1030, in cassandra.cqltypes.UserType.serialize_safe
KeyError: 0

What am I doing wrong here? Any help would be appreciated.


Solution

  • The proper way to insert rows containing UDTs with prepared statements in the Python driver is to define a simple class with the same structure as the UDT and to use instances of it when building the values for the insert statement.
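
The `KeyError: 0` in the traceback is consistent with the driver first trying positional access (`value[0]`) on the UDT value and only falling back to attribute access when that raises a `TypeError`. A plain dict answers `[0]` with a `KeyError` instead, so the fallback never happens. The sketch below is a simplified stand-in for that lookup, not the driver's actual code; `read_field` and `Vehicle` are hypothetical names used only for illustration.

```python
def read_field(value, index, fieldname):
    """Simplified stand-in for a 'tuple first, then attribute' field lookup."""
    try:
        # Works for tuples and lists: fields are taken by position.
        return value[index]
    except TypeError:
        # Objects that do not support indexing fall back to attribute access.
        return getattr(value, fieldname, None)


class Vehicle:
    """Plain class whose attribute names match the UDT field names."""

    def __init__(self, num_vehicles, speed):
        self.num_vehicles = num_vehicles
        self.speed = speed


# A tuple and a plain object both resolve the first field.
from_tuple = read_field((7, [72.78]), 0, "num_vehicles")
from_object = read_field(Vehicle(7, [72.78]), 0, "num_vehicles")

# A dict is indexable, so [0] is treated as a key lookup and fails.
try:
    read_field({"num_vehicles": 7, "speed": [72.78]}, 0, "num_vehicles")
    dict_outcome = "ok"
except KeyError as exc:
    dict_outcome = f"KeyError: {exc}"
```

This is why wrapping the dict in a small class (or passing a tuple in field order) avoids the error, while passing the dict itself does not.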

    I have prepared some sample code to demonstrate this: see the UdtVehicle class and how it is instantiated when building the arguments of the insert statement in the batch.

    The sample code also demonstrates a few other things, selected by the command-line argument: read shows what you get when simply reading the rows as they are; read_udt shows how to register a UDT with your Cluster so that the returned rows are cast into your Python class; insert is a sanity check for a single-row (non-batch) prepared insert, with the UDT handled as explained above; and insertb does the same within a batch.

    For more information on handling UDTs, see this page: https://docs.datastax.com/en/developer/python-driver/latest/user_defined_types . Note that unprepared statements need a slightly different approach (though you probably want prepared statements in a production application anyway).

    Looking at your code above, you might want to convert the vehicles into their UDT class right within the function you posted (depending on what exactly you receive through the Kafka stream). Also be aware that a batch is not executed until you explicitly invoke session.execute(batch), which is not shown in your code.
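
A minimal sketch of that conversion, assuming the record has exactly the shape printed in the question (note the capitalised 'Vehicles' and 'num_Vehicles' keys) and that the map key stays road_id as in the posted code. `record_to_params` is a hypothetical helper name; the returned tuple is what would be passed to batch.add or session.execute together with the prepared statement.

```python
import datetime


class UdtVehicle:
    """Mirror of the UDTVehicle type: same field names as the UDT."""

    def __init__(self, num_vehicles, speed):
        self.num_vehicles = num_vehicles
        self.speed = speed


def record_to_params(record):
    """Turn one consumed record into the bind values for the prepared insert."""
    raw = record["Vehicles"]
    vehicle = UdtVehicle(
        num_vehicles=raw["num_Vehicles"],  # key spelling taken from the printed record
        speed=list(raw["speed"]),
    )
    # Assumption: the map key is the road_id, as in the code from the question.
    vehicles_map = {record["road_id"]: vehicle}
    return (
        record["road_id"],
        record["timestamp"],
        record["radar_id"],
        vehicles_map,
    )


# The record printed in the question.
record = {
    "timestamp": datetime.datetime(2022, 1, 1, 8, 0, 30),
    "num_Vehicles": 7,
    "road_id": 9,
    "radar_id": 11,
    "Vehicles": {
        "num_Vehicles": 7,
        "speed": [72.78, 62.67, 85.15, 75.51, 83.95, 76.39, 57.92],
    },
}
params = record_to_params(record)
```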

    A few more remarks:

    1. Do not prepare the statement at each write operation: this is a resource-consuming antipattern. Instead, prepare the statement once, keep it cached somewhere (self.prepared_statement is a natural candidate), and reuse it for better performance.
    2. The UDT as you have structured it does not guarantee that num_vehicles == len(speed). If that must be enforced, a different model would probably be better (but again, this depends on your use case).
    3. Take extra care in evaluating whether you really need a batch in this case. Batches in Cassandra are not a way to speed up indiscriminate bulk insertions! For that, you can just issue a number of concurrent simple writes and the driver will take care of the rest. If the statements in a given batch involve different partitions (in your case, different road_id values), then a single batch should probably be avoided. Read more here: https://batey.info/cassandra-anti-pattern-cassandra-logged.html, https://www.batey.info/cassandra-anti-pattern-misuse-of.html .
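
Remarks 1 and 3 could be sketched roughly as follows. This is only an illustration under assumptions: `RoadTrafficWriter` and `group_by_partition` are hypothetical names, and `session` is expected to be an already connected driver session (anything exposing prepare and execute works for the sketch).

```python
from collections import defaultdict

INSERT_CQL = (
    "INSERT INTO road_traffic (road_id, timestamp, radar_id, vehicles) "
    "VALUES (?, ?, ?, ?);"
)


class RoadTrafficWriter:
    """Hypothetical wrapper: prepares the insert once and reuses it."""

    def __init__(self, session):
        self.session = session
        # Prepared a single time, when the writer is created.
        self.prepared_insert = session.prepare(INSERT_CQL)

    def write(self, params):
        # params: (road_id, timestamp, radar_id, vehicles)
        return self.session.execute(self.prepared_insert, params)


def group_by_partition(rows):
    """Group bind-value tuples by road_id (the partition key, first item)."""
    grouped = defaultdict(list)
    for params in rows:
        grouped[params[0]].append(params)
    return dict(grouped)
```

Each group returned by group_by_partition touches a single partition, so it could go into its own batch if a batch is really needed; otherwise the rows can simply be written one by one through write.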

    And now the sample code you can start from. (Tested on Cassandra 4.1)

    import sys
    import datetime
    
    from cassandra.cluster import Cluster
    from cassandra.auth import PlainTextAuthProvider
    from cassandra.query import BatchStatement
    
    class UdtVehicle():
        def __init__(self, num_vehicles, speed):
            self.num_vehicles = num_vehicles
            self.speed = speed
    
        def __repr__(self):
            return f"UdtVehicle[{self.num_vehicles} vehicles, speeds={', '.join('%.2f' % sp for sp in self.speed)}]"
    
    
    if __name__ == '__main__':
        cluster = Cluster(
            ["CONTACT_POINT"],
            auth_provider=PlainTextAuthProvider(
                "USERNAME",
                "PASSWORD",
            ),
        )
        session = cluster.connect("KEYSPACE_NAME")
    
        cmd = sys.argv[1] if len(sys.argv) > 1 else "read"
        if cmd == "read":
            for r in session.execute("select * from road_traffic;"):
                print(str(r))
                print('-'*20)
                _one_udt = list(r.vehicles.values())[0]
                print(type(_one_udt))
                print(str(_one_udt))
                print('='*20)
        elif cmd == "read_udt":
            cluster.register_user_type("KEYSPACE_NAME", "udtvehicle", UdtVehicle)
            for r in session.execute("select * from road_traffic;"):
                print(str(r))
                print('-'*20)
                _one_udt = list(r.vehicles.values())[0]
                print(type(_one_udt))
                print(str(_one_udt))
                print('='*20)
        elif cmd == "insert":
            insertion_prepared = session.prepare("INSERT INTO road_traffic (road_id, timestamp, radar_id, vehicles) VALUES (?, ?, ?, ?);")
            road_id = 123
            timestamp = datetime.datetime.now()
            radar_id = 456
            vehicles = {
                10: UdtVehicle(
                    num_vehicles=1,
                    speed=[100.1, 100.2],
                ),
                999: UdtVehicle(
                    num_vehicles=100,
                    speed=[],
                ),
                11: UdtVehicle(
                    num_vehicles=3,
                    speed=[0.1, 0.2, 0.3],
                ),
            }
            result = session.execute(insertion_prepared, (road_id, timestamp, radar_id, vehicles))
        elif cmd == "insertb":
            insertion_prepared = session.prepare("INSERT INTO road_traffic (road_id, timestamp, radar_id, vehicles) VALUES (?, ?, ?, ?);")
            batch = BatchStatement()
            # as per best practices, this batch will be a single-partition batch!
            road_id = 100
            t0 = datetime.datetime.now()
            for row_id in range(3):
                timestamp = t0 + datetime.timedelta(hours=row_id)
                radar_id = 1000 + row_id
                vehicles = {
                    (1000+row_id+3): UdtVehicle(
                        num_vehicles=row_id+30,
                        speed=[10.01] * (1+row_id),
                    )
                }
                batch.add(insertion_prepared, (road_id, timestamp, radar_id, vehicles))
            # run the batch
            session.execute(batch)
        else:
            print("Unknown command '%s'" % cmd)