Search code examples
Tags: cassandra, cassandra-python-driver

Cannot use BatchQuery in a paged result handler class


The Python driver provides an event/callback approach for paging through large results:

https://datastax.github.io/python-driver/query_paging.html

There is also a BatchQuery class for use with the cqlengine ORM, and it's quite handy:

https://datastax.github.io/python-driver/cqlengine/batches.html?highlight=batchquery

Now I need to execute a BatchQuery inside the callback handler of a paged result, but the script just gets stuck while iterating over the current page.

I guess this is because Cassandra sessions cannot be shared between threads, while both BatchQuery and the "paged result" approach use threads to manage events and callbacks.

Any idea how to magically sort this situation out? Below is some code:

# paged.py
class PagedQuery:
    """
    Run a query asynchronously and hand each page of results to a handler.

    The query is sent with a fetch size of 500 rows. For every page the driver
    calls ``handle_page``, which passes the rows to ``handler`` and only then
    requests the next page. ``finished_event`` is set once the last page has
    been handled or an error occurred; the error (from the query, from the
    handler, or a missing handler) is stored in ``error``. ``count`` is the
    number of rows successfully passed to the handler.

    The handler runs inside the driver's paging callback, so it should not
    execute queries itself; hand the rows to another thread instead.

    >>> query = "SELECT * FROM ks.my_table WHERE collectionid=123 AND ttype='collected'"  # define query
    >>> def handler(page):  # define result page handler function
    ...     for t in page:
    ...         print(t)
    >>> pq = PagedQuery(query, handler)  # instantiate a PagedQuery object
    >>> pq.finished_event.wait()  # wait for the PagedQuery to handle all results
    >>> if pq.error:
    ...     raise pq.error
    """
    def __init__(self, query, handler=None):
        """
        Start executing ``query`` and register the paging callbacks.

        :param query: CQL query string.
        :param handler: callable invoked with each page of rows. If it is
            missing, a RuntimeError is recorded in ``error`` when the first
            page arrives.
        """
        session = new_cassandra_session()
        session.row_factory = named_tuple_factory
        statement = SimpleStatement(query, fetch_size=500)
        future = session.execute_async(statement)
        self.count = 0
        self.error = None
        self.finished_event = Event()
        self.query = query
        self.session = session
        self.handler = handler
        self.future = future
        # Registered last so every attribute the callbacks use already exists.
        self.future.add_callbacks(
            callback=self.handle_page,
            errback=self.handle_error
        )

    def handle_page(self, page):
        """
        Pass one page of rows to the handler, then fetch the next page or finish.

        Any exception is recorded through ``handle_error``. An exception escaping
        this callback could not reach the thread waiting on ``finished_event``,
        and the event would never be set, so that thread would block forever.
        """
        try:
            if not self.handler:
                raise RuntimeError('A page handler function was not defined for the query')
            self.handler(page)
            self.count += len(page)

            if self.future.has_more_pages:
                self.future.start_fetching_next_page()
            else:
                self.finished_event.set()
        except Exception as exc:
            self.handle_error(exc)

    def handle_error(self, exc):
        """Record ``exc`` as the query's error and wake up anyone waiting."""
        self.error = exc
        self.finished_event.set()

# main.py
# script using class above
def main():
    """
    Page through the query and save each page's rows in one unlogged batch.

    This version gets stuck: ``handle_page`` runs inside the driver's paging
    callback and executes a BatchQuery there.
    """

    query = 'SELECT * FROM ks.my_table WHERE collectionid=10 AND ttype=\'collected\''

    def handle_page(page):
        # Paging callback: PagedQuery calls this with each page of rows.

        b = BatchQuery(batch_type=BatchType.Unlogged)
        for obj in page:
            process(obj)  # some updates on obj...
            obj.batch(b).save()

        # NOTE: this executes a query from inside the paging callback, which is
        # where the script hangs; queries must not be run from the callback.
        b.execute()

    pq = PagedQuery(query, handle_page)
    pq.finished_event.wait()

    # NOTE(review): relies on PagedQuery incrementing ``count``; confirm the
    # class does, otherwise this message is always printed.
    if not pq.count:
        print('Empty queryset. Please, check parameters')

if __name__ == '__main__':
    main()

Solution

  • Because you cannot execute queries in the ResponseFuture's event loop (i.e. from the paging callback), just iterate over each page in the callback and put the objects on a queue, then save them in batches from separate worker threads. We use Kafka queues to persist objects, but in this case a thread-safe Python Queue works well.

    import datetime
    import logging
    import queue
    import sys
    import threading
    import time

    from cassandra.auth import PlainTextAuthProvider
    from cassandra.cluster import Cluster, default_lbp_factory, NoHostAvailable
    from cassandra.connection import Event
    from cassandra.cqlengine.connection import (Connection, DEFAULT_CONNECTION, _connections)
    from cassandra.cqlengine.query import BatchQuery
    from cassandra.query import named_tuple_factory, PreparedStatement, SimpleStatement
    from cassandra.query import BatchType
    from cassandra.util import OrderedMapSerializedKey

    from smfrcore.models.cassandra import Tweet
    
    # Sentinel object: a worker that dequeues it flushes its batch and exits.
    STOP_QUEUE = object()
    logging.basicConfig(
        format='[%(levelname)s] (%(threadName)-9s) %(message)s',
        level=logging.DEBUG,
    )
    
    
    def new_cassandra_session(retries=5, retry_delay=10):
        """
        Connect to Cassandra and register the session as cqlengine's default connection.

        :param retries: number of retries after a failed attempt, so at most
            ``retries + 1`` connection attempts are made.
        :param retry_delay: seconds to sleep between attempts.
        :returns: the connected session (no default timeout, fetch size 500,
            ``named_tuple_factory`` rows).
        :raises Exception: the error from the last attempt when every attempt
            fails. Earlier versions returned ``None`` here, which made callers
            fail later with an unrelated AttributeError.
        """
        _cassandra_user = 'user'
        _cassandra_password = 'xxxx'  # NOTE(review): hard-coded credentials; load from config instead
        last_error = None
        for attempt in range(max(retries, 0) + 1):
            cassandra_cluster = None
            try:
                cluster_kwargs = {'compression': True,
                                  'load_balancing_policy': default_lbp_factory(),
                                  'executor_threads': 10,
                                  'idle_heartbeat_interval': 10,
                                  'idle_heartbeat_timeout': 30,
                                  'auth_provider': PlainTextAuthProvider(username=_cassandra_user, password=_cassandra_password)}

                cassandra_cluster = Cluster(**cluster_kwargs)
                cassandra_session = cassandra_cluster.connect()
                cassandra_session.default_timeout = None
                cassandra_session.default_fetch_size = 500
                cassandra_session.row_factory = named_tuple_factory
                # Register this session as cqlengine's DEFAULT_CONNECTION.
                cassandra_default_connection = Connection.from_session(DEFAULT_CONNECTION, session=cassandra_session)
                _connections[DEFAULT_CONNECTION] = cassandra_default_connection
                _connections[str(cassandra_session)] = cassandra_default_connection
            except Exception as e:  # also covers NoHostAvailable
                last_error = e
                if cassandra_cluster is not None:
                    # Release the failed Cluster's threads/connections before retrying.
                    cassandra_cluster.shutdown()
                if attempt < retries:
                    print('Cassandra host not available yet...sleeping %s secs: ' % retry_delay, str(e))
                    time.sleep(retry_delay)
                else:
                    print('Cassandra host not available, giving up: ', str(e))
            else:
                return cassandra_session
        raise last_error
    
    
    class PagedQuery:
        """
        Run a query asynchronously and hand each page of results to a handler.

        The query is sent with a fetch size of 500 rows. For every page the
        driver calls ``handle_page``, which passes the rows to ``handler`` and
        only then requests the next page. ``finished_event`` is set once the
        last page has been handled or an error occurred; the error (from the
        query, from the handler, or a missing handler) is stored in ``error``.
        ``count`` is the number of rows successfully passed to the handler.

        The handler runs inside the driver's paging callback, so it should not
        execute queries itself; hand the rows to another thread instead.

        >>> query = "SELECT * FROM ks.my_table WHERE collectionid=123 AND ttype='collected'"  # define query
        >>> def handler(page):  # define result page handler function
        ...     for t in page:
        ...         print(t)
        >>> pq = PagedQuery(query, handler)  # instantiate a PagedQuery object
        >>> pq.finished_event.wait()  # wait for the PagedQuery to handle all results
        >>> if pq.error:
        ...     raise pq.error
        """
        def __init__(self, query, handler=None):
            """
            Start executing ``query`` and register the paging callbacks.

            :param query: CQL query string.
            :param handler: callable invoked with each page of rows. If it is
                missing, a RuntimeError is recorded in ``error`` when the first
                page arrives.
            """
            session = new_cassandra_session()
            session.row_factory = named_tuple_factory
            statement = SimpleStatement(query, fetch_size=500)
            future = session.execute_async(statement)
            self.count = 0
            self.error = None
            self.finished_event = Event()
            self.query = query
            self.session = session
            self.handler = handler
            self.future = future
            # Registered last so every attribute the callbacks use already exists.
            self.future.add_callbacks(
                callback=self.handle_page,
                errback=self.handle_error
            )

        def handle_page(self, page):
            """
            Pass one page of rows to the handler, then fetch the next page or finish.

            Any exception is recorded through ``handle_error``. An exception
            escaping this callback could not reach the thread waiting on
            ``finished_event``, and the event would never be set, so that
            thread would block forever.
            """
            try:
                if not self.handler:
                    raise RuntimeError('A page handler function was not defined for the query')
                self.handler(page)
                self.count += len(page)

                if self.future.has_more_pages:
                    self.future.start_fetching_next_page()
                else:
                    self.finished_event.set()
            except Exception as exc:
                self.handle_error(exc)

        def handle_error(self, exc):
            """Record ``exc`` as the query's error and wake up anyone waiting."""
            self.error = exc
            self.finished_event.set()
    
    
    
    def main(num_workers=4, batch_size=500):
        """
        Page through the query and save every processed row in unlogged batches.

        Queries must not be executed inside the paging callback. The callback
        therefore only processes rows and puts them on a thread-safe queue.
        ``num_workers`` worker threads drain the queue and execute a batch every
        ``batch_size`` objects, flushing the remainder when they receive
        STOP_QUEUE.

        :raises ValueError: if ``num_workers`` or ``batch_size`` is less than 1.
        :raises Exception: the error recorded by PagedQuery, if any.
        """
        if num_workers < 1 or batch_size < 1:
            raise ValueError('num_workers and batch_size must be at least 1')

        query = 'SELECT * FROM ks.my_table WHERE collectionid=1 AND ttype=\'collected\''

        q = queue.Queue()
        enqueued = 0

        def worker():
            """Save objects from ``q`` in batches until STOP_QUEUE is received."""
            saved = 0
            b = BatchQuery(batch_type=BatchType.Unlogged)
            while True:
                tweet = q.get()
                try:
                    if tweet is STOP_QUEUE:
                        b.execute()
                        logging.info(' >>>>>>>>>>>>>>>> Executed last batch for this worker!!!!')
                        return

                    tweet.batch(b).save()
                    saved += 1
                    if not saved % batch_size:
                        # Swap in a fresh batch before executing, so a failed
                        # batch is not re-sent together with the next one.
                        full_batch, b = b, BatchQuery(batch_type=BatchType.Unlogged)
                        full_batch.execute()
                        logging.info('>>>>>>>>>>>>>>>> Batch executed in this worker: geotagged so far: %s', saved)
                except Exception:
                    # Keep consuming: a dead worker stops draining the queue,
                    # which could leave q.join() below blocked forever.
                    logging.exception('Worker failed to save an object')
                    if tweet is STOP_QUEUE:
                        return
                finally:
                    # Acknowledge every item (sentinel included) so q.join() completes.
                    q.task_done()

        def handle_page(page):
            """Paging callback: process rows and queue them; never run queries here."""
            nonlocal enqueued
            for obj in page:
                process(obj)  # some updates on obj...
                q.put(obj)
                enqueued += 1

        pq = PagedQuery(query, handle_page)

        # Start the consumers; rows queued before they start simply wait in q.
        # Daemon threads, so workers blocked on q.get() cannot keep the process
        # alive if main() is interrupted before the sentinels are sent.
        threads = []
        for i in range(num_workers):
            t = threading.Thread(target=worker, name='worker-%d' % i, daemon=True)
            t.start()
            threads.append(t)

        pq.finished_event.wait()
        # block until every queued object has been handled by a worker
        q.join()

        # stop workers by sending one STOP_QUEUE sentinel per worker
        for _ in threads:
            q.put(STOP_QUEUE)

        for t in threads:
            t.join()

        if pq.error:
            raise pq.error

        if not enqueued:
            print('Empty queryset. Please, check parameters')

    if __name__ == '__main__':
        sys.exit(main())