
Refactor a (multi)generator python function


I am looking for a Pythonic way to further refactor the function event_stream() below. It is simplified and abstracted from a Python Flask web app I am writing to experiment with Python.

The function is a generator with an endless loop that checks a number of objects (currently implemented as dicts) for changes made elsewhere in the application.

If an object has changed, an event is yielded, which the caller function sse_request() then uses to create a server-sent event (SSE).
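For reference, the text/event-stream format expects each event as one or more `data:` lines terminated by a blank line. The helper below is a hypothetical sketch (the name `format_sse` and its optional `event` parameter are my own, not part of Flask) of the kind of string that functions like `parrot_event()` might return:

```python
import json


def format_sse(payload, event=None):
    """Serialise *payload* as JSON and wrap it as one server-sent event.

    Hypothetical helper: the optional *event* name becomes an ``event:``
    field, which browsers dispatch to listeners registered for that name.
    """
    lines = []
    if event is not None:
        lines.append(f"event: {event}")
    # json.dumps without indent produces a single line, so one data: field suffices
    lines.append(f"data: {json.dumps(payload)}")
    # The trailing blank line ("\n\n") terminates the event
    return "\n".join(lines) + "\n\n"


print(format_sse({"state": "dead"}, event="parrotchange"), end="")
```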

def event_stream():
    parrot_old = parrot.copy()
    grail_old = grail.copy()
    walk_old = walk.copy()
    while True:
        print("change poller loop")
        gevent.sleep(0.5)

        parrot_changed, parrot_old = check_parrot(parrot_new=parrot, parrot_old=parrot_old)
        if parrot_changed:
            yield parrot_event(parrot)

        grail_changed, grail_old = check_grail(grail_new=grail, grail_old=grail_old)
        if grail_changed:
            yield grail_event(grail)

        walk_changed, walk_old = check_walk(walk_new=walk, walk_old=walk_old)
        if walk_changed:
            yield walk_event(walk)


@app.route('/server_events')
def sse_request():
    return Response(
            event_stream(),
            mimetype='text/event-stream')

While event_stream() is currently short enough to be readable, it breaks the principle of "do one thing, and do it well", because it tracks changes to three different objects. If I were to add further objects to track (e.g. an "inquisitor" or a "brain"), it would become unwieldy.


Solution

  • Short Answer:


    Two steps of refactoring have been applied to the function event_stream(). These are explained in chronological order in the “Long Answer” below, and in summary here:

    The original function had multiple yields: one per “object” whose changes are to be tracked. Adding further objects implied adding further yields.

    • The first refactoring eliminated this by looping through a structure that stores the objects to be tracked.
    • The second refactoring changed the “objects” from dicts into instances of real classes. This allowed the storage structure and the loop to become radically simpler: the "old copies" and the functions used to spot changes and create the resulting server-sent events could be moved into the objects.

    The fully refactored code is at the bottom of the "Long Answer".

    Long Answer, Step by Step:


    First Refactor:

    Below is my first refactoring, inspired by jgr0's answer. His initial suggestion did not work straight away because it used a dict as a dict key, and keys must be hashable (which dicts are not). It looks like we both hit on using a string as the key and moving the object/dict into the value alongside it (under an "object" key).
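    The hashability problem is easy to reproduce: a string works as a dict key, while a dict (being mutable) raises TypeError. A minimal illustration with a made-up `parrot` dict:

```python
parrot = {"state": "resting"}

# A string is hashable, so it works as a dict key
change_objects = {"parrot": {"object": parrot}}

# A dict is mutable and therefore unhashable, so using it as a key fails
try:
    bad_lookup = {parrot: "check_parrot"}
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)  # the message reports an unhashable type
```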

    This solution works where the "object" is either a dict or a list of dicts (hence the use of deepcopy(), which also copies the inner dicts rather than sharing them).
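    Why deepcopy() rather than .copy() matters for a list of dicts: a shallow copy duplicates only the outer list, so the "old" snapshot shares its inner dicts with the live object and an in-place change is never detected. A minimal illustration (the `walk` contents are made up):

```python
from copy import deepcopy

walk = [{"step": 1, "silly": True}, {"step": 2, "silly": False}]

shallow_old = walk.copy()   # new outer list, same inner dict objects
deep_old = deepcopy(walk)   # new outer list and new inner dicts

walk[1]["silly"] = True     # mutate an inner dict in place

print(walk != shallow_old)  # False: the shallow "snapshot" changed too
print(walk != deep_old)     # True: the deep copy kept the old state
```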

    Each object has a "checker" function to check whether it has changed (e.g. check_parrot), and an "event" function (e.g. parrot_event) to build the event to be yielded.
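    The checkers themselves are not shown here. Judging only from how they are called, each takes the new and old versions and returns a (changed, old) pair. Under that assumption, a single generic checker (the name `check_changed` is hypothetical) might look like this:

```python
from copy import deepcopy


def check_changed(obj_new, obj_old):
    """Hypothetical generic checker matching the (new, old) -> (changed, old) convention.

    Returns (changed, snapshot), where snapshot is what the caller should keep
    as the "old" copy for the next comparison.
    """
    if obj_new != obj_old:
        return True, deepcopy(obj_new)
    return False, obj_old


changed, old = check_changed({"state": "dead"}, {"state": "resting"})
print(changed, old)  # True {'state': 'dead'}
```

    If every checker is a plain equality test like this one, check_parrot, check_grail and check_walk could share one function.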

    Additional arguments can be configured for the xxx_event() functions via the "args" key, which is passed as *args.
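    Handling the optional "args" key can also be written with a single call path, because unpacking an empty tuple passes no extra arguments. One caveat: a tuple such as (walk['sillyness'],) is evaluated once, when change_objects is built, so later changes to walk['sillyness'] are not seen by the event function. The sketch below uses a made-up `parrot_event` to show both points:

```python
def parrot_event(parrot, sillyness=None):
    # Hypothetical event builder standing in for the app's real one
    return f"parrot={parrot['state']} sillyness={sillyness}"


walk = {"sillyness": 3}
parrot = {"state": "resting"}

entry = {"object": parrot, "event": parrot_event, "args": (walk["sillyness"],)}
plain = {"object": parrot, "event": parrot_event}

# One call path for both cases: a missing "args" key unpacks to nothing
for value in (entry, plain):
    print(value["event"](value["object"], *value.get("args", ())))

walk["sillyness"] = 11
# Still 3: the tuple captured the value, not a live reference to walk
print(entry["event"](entry["object"], *entry.get("args", ())))
```

    If the live value is wanted, one option would be to store a zero-argument callable instead and call it inside the loop.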

    The objects in copies{} are fixed at the time of copying, while those configured in change_objects{} are references, and thus reflect the latest state of the objects. Comparing the two allows changes to be identified.
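    A small illustration of reference versus snapshot. It also shows an assumption this approach relies on: the rest of the application must mutate the dicts in place, because rebinding the name creates a new object that the stored reference never sees.

```python
from copy import deepcopy

grail = {"found": False}

tracked = grail             # a reference: always shows the current state
snapshot = deepcopy(grail)  # a copy: frozen at the moment it was taken

grail["found"] = True       # in-place change made "elsewhere" in the app
print(tracked != snapshot)  # True: the change is visible through the reference

grail = {"found": "elsewhere"}  # rebinding the name creates a new object...
print(tracked is grail)         # False: the tracker still watches the old dict
```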

    While the event_stream() function is now arguably less readable than my original, I no longer need to touch it to track changes to further objects: they are simply added to change_objects{}.

    from copy import deepcopy
    
    import gevent
    from flask import Response
    
    # dictionary of objects whose changes should be tracked by event_stream()
    change_objects = {
        "parrot": {
            "object": parrot,
            "checker": check_parrot,
            "event": parrot_event,
            # note: evaluated once, when change_objects is built
            "args": (walk['sillyness'],),
        },
        "grail": {
            "object": grail,
            "checker": check_grail,
            "event": grail_event,
        },
        "walk": {
            "object": walk,
            "checker": check_walk,
            "event": walk_event,
        },
    }
    
    def event_stream(change_objects):
        # take an independent snapshot of every tracked object
        copies = {}
        for key, value in change_objects.items():
            copies[key] = {"obj_old": deepcopy(value["object"])}  # ensure a true copy, not a reference!
    
        while True:
            print("change poller loop")
            gevent.sleep(0.5)
            for key, value in change_objects.items():
                obj_new = deepcopy(value["object"])  # use the same version in the check and yield functions
                obj_changed, copies[key]["obj_old"] = value["checker"](obj_new, copies[key]["obj_old"])
                if obj_changed:
                    # handle additional arguments to the event function
                    if "args" in value:
                        args = value["args"]
                        yield value["event"](obj_new, *args)
                    else:
                        yield value["event"](obj_new)
    
    
    @app.route('/server_events')
    def sse_request():
        return Response(
                event_stream(change_objects),
                mimetype='text/event-stream')
    

    Edit: Second Refactor, using Objects instead of Dicts:

    If I use objects rather than the original dicts:

    • The change-identification and event-building methods can be moved into the objects.
    • The objects can also store the "old" state.

    Everything becomes simpler and even more readable: event_stream() no longer needs a copies{} dict (so it has only one structure to loop over), and change_objects is now a simple list of tracker objects:

    import gevent
    from flask import Response
    
    def event_stream(change_objects):
        while True:
            print("change poller loop")
            gevent.sleep(0.5)
            for obj in change_objects:
                if obj.changed():
                    yield obj.sse_event()
    
    @app.route('/server_events')
    def sse_request():
        # List of objects whose changes are tracked by event_stream().
        # The list is built inside sse_request(), so each client gets its own
        # 'private' trackers (and therefore its own "old" state).
        change_objects = [
            ParrotTracker(),
            GrailTracker(),
            WalkTracker(),
            # ...
            SpamTracker(),
        ]
        return Response(
                event_stream(change_objects),
                mimetype='text/event-stream')
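    Building the list per request matters because each tracker keeps its own old state: if two clients' generators shared the same tracker instances, whichever polled first would update old, and the other client would never see the change. A self-contained sketch with a minimal stand-in tracker (`CounterTracker` and `poll_once` are hypothetical names):

```python
from copy import deepcopy

state = {"count": 0}


class CounterTracker:
    # Minimal stand-in for ParrotTracker: watches the module-level dict `state`
    def __init__(self):
        self.new = state
        self.old = deepcopy(state)

    def changed(self):
        if self.new != self.old:
            self.old = deepcopy(self.new)
            return True
        return False


def poll_once(trackers):
    # One pass of the event_stream() loop, without sleeping or looping forever
    return [t for t in trackers if t.changed()]


shared = [CounterTracker()]
client_a, client_b = shared, shared                            # both clients share one list
private_a, private_b = [CounterTracker()], [CounterTracker()]  # one list per client

state["count"] += 1

shared_counts = (len(poll_once(client_a)), len(poll_once(client_b)))
private_counts = (len(poll_once(private_a)), len(poll_once(private_b)))
print(shared_counts, private_counts)  # (1, 0) (1, 1)
```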
    

    An example tracker class is:

    import json
    from copy import deepcopy
    
    from data.parrot import parrot
    
    class ParrotTracker:
    
        def __init__(self):
            self.old = deepcopy(parrot)
            self.new = parrot  # a reference: parrot must be mutated in place, not rebound
    
        def sse_event(self):
            data = self.new.copy()  # copy so the tracked dict itself is not modified
            data['type'] = 'parrotchange'
            data = json.dumps(data)
            return "data: {}\n\n".format(data)
    
        def truecopy(self, orig):
            return deepcopy(orig)  # ensure it is a copy, not a reference
    
        def changed(self):
            if self.new != self.old:
                self.old = self.truecopy(self.new)
                return True
            else:
                return False
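    If ParrotTracker, GrailTracker and WalkTracker differ only in which dict they watch and the event type they emit, one possible further step is a single parameterised class. This is a sketch under that assumption: `DictTracker`, its `event_type` parameter and the `interval` argument to event_stream() are my own additions, and time.sleep stands in for gevent.sleep so the example runs without gevent.

```python
import json
import time
from copy import deepcopy


class DictTracker:
    """Watches one dict and emits an SSE message when it changes (hypothetical)."""

    def __init__(self, obj, event_type):
        self.new = obj              # live reference; must be mutated in place
        self.old = deepcopy(obj)    # snapshot to compare against
        self.event_type = event_type

    def changed(self):
        if self.new != self.old:
            self.old = deepcopy(self.new)
            return True
        return False

    def sse_event(self):
        data = dict(self.new)       # shallow copy, so the tracked dict is untouched
        data["type"] = self.event_type
        return "data: {}\n\n".format(json.dumps(data))


def event_stream(change_objects, interval=0.5):
    while True:
        time.sleep(interval)        # gevent.sleep(interval) in the gevent-based app
        for obj in change_objects:
            if obj.changed():
                yield obj.sse_event()


parrot = {"state": "resting"}
grail = {"found": False}

stream = event_stream([DictTracker(parrot, "parrotchange"),
                       DictTracker(grail, "grailchange")], interval=0)
grail["found"] = True
print(next(stream), end="")  # data: {"found": true, "type": "grailchange"}
```

    With this, sse_request() would build its list from DictTracker instances, and a dedicated subclass would only be needed for a tracker that does something genuinely different. Note that next() on the stream blocks until some tracked dict changes.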
    

    I think it now smells much better!