
Trial unittests using Autobahn WebSocket


I'm trying to write unittests for my application that uses Autobahn.

I want to test my controllers, which receive data from the protocol, parse it and react to it.

But when my test reaches the point where the protocol should be disconnected (self.sendClose), it raises the error

exceptions.AttributeError: 'MyProtocol' object has no attribute 'state'.

I tried calling makeConnection with a proto_helpers.StringTransport, but then I get errors too:

exceptions.AttributeError: StringTransport instance has no attribute 'setTcpNoDelay'

I'm using trial, and I don't want to run a dummy server/client purely for testing purposes, because that is not recommended.

How should I write my tests so that I can test functions that send data, read data, disconnect, etc. using a fake connection and trial?


Solution

  • It is difficult to say exactly what is going on without seeing the MyProtocol class. The problem sounds a lot like it is caused by directly messing around with low-level functions, and therefore also with the state attribute of the WebSocket class, which represents the internal state of the WebSocket connection.

    According to the autobahn reference docs, the WebSocketProtocol APIs that you can directly use and override are:

    • onOpen
    • onMessage
    • onClose
    • sendMessage
    • sendClose

    Your approach of using StringTransport to test your protocol is not ideal. The problem lies in the fact that MyProtocol is a thin layer on top of the WebSocketProtocol framework provided by autobahn, which, for better or worse, hides the details of managing the connection, the transport and the internal protocol state.

    If you think about it, you want to test your own code, not WebSocketProtocol. Therefore, if you do not want to embed a dummy server or client, your best bet is to directly test the methods that MyProtocol overrides.

    An example of what I mean:

    from twisted.trial import unittest
    from twisted.test import proto_helpers
    from autobahn.websocket import WebSocketProtocol, WebSocketServerProtocol, WebSocketServerFactory

    class MyPublisher(object):
        cbk = None

        def publish(self, msg):
            if self.cbk:
                self.cbk(msg)

    class MyProtocol(WebSocketServerProtocol):

        def __init__(self, publisher):
            WebSocketServerProtocol.__init__(self)
            # Register sendMessage as the publisher's callback
            publisher.cbk = self.sendMessage

        def onMessage(self, msg, binary):
            # Stupid echo
            self.sendMessage(msg)

    class NotificationTest(unittest.TestCase):

        class MyProtocolFactory(WebSocketServerFactory):
            def __init__(self, publisher):
                WebSocketServerFactory.__init__(self, "ws://127.0.0.1:8081")
                self.publisher = publisher
                self.openHandshakeTimeout = None

            def buildProtocol(self, addr):
                protocol = MyProtocol(self.publisher)
                protocol.factory = self
                # Hybi version 13 is supported by pretty much everyone (apart from IE < 10 and old Android browsers)
                protocol.websocket_version = 13
                return protocol

        def setUp(self):
            publisher = MyPublisher()
            factory = NotificationTest.MyProtocolFactory(publisher)
            protocol = factory.buildProtocol(None)
            transport = proto_helpers.StringTransport()
            # StringTransport has no setTcpNoDelay, which autobahn calls on connect
            def play_dumb(*args): pass
            transport.setTcpNoDelay = play_dumb
            protocol.makeConnection(transport)
            self.protocol, self.transport, self.publisher = protocol, transport, publisher

        def test_onMessage(self):
            # The following 2 lines are the problematic part: they explicitly manipulate
            # hidden state that your implementation should not be concerned with!
            self.protocol.state = WebSocketProtocol.STATE_OPEN
            self.protocol.websocket_version = 13
            self.protocol.onMessage("Whatever", False)
            # Skip the 2-byte header of a short, unmasked server frame
            self.assertEqual(self.transport.value()[2:], 'Whatever')

        def test_push(self):
            # Same problem as above: forcing the hidden connection state
            self.protocol.state = WebSocketProtocol.STATE_OPEN
            self.protocol.websocket_version = 13
            self.publisher.publish("Hi there")
            self.assertEqual(self.transport.value()[2:], 'Hi there')
    

    As you might have noticed, using StringTransport here is very cumbersome: you need knowledge of the underlying framework and have to bypass its state management, something you don't really want to do. Unfortunately, autobahn does not provide a ready-to-use test object that would allow easy state manipulation, so my suggestion of using dummy servers and clients still stands.
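    For handlers that only call sendMessage and sendClose, one more option in the spirit of testing the overridden methods directly is to swap those two methods for recorders on the protocol instance, which avoids both the transport and the hidden state. The sketch below is an assumption-laden illustration, not autobahn's own testing API: FakeWebSocketServerProtocol is a hypothetical stand-in for WebSocketServerProtocol so the code runs without autobahn, the onMessage(msg, binary) signature is assumed to match the old autobahn.websocket API used above, and the "BYE" close command is a hypothetical piece of application logic.

```python
import unittest
from unittest import mock


class FakeWebSocketServerProtocol(object):
    """Hypothetical stand-in for autobahn's WebSocketServerProtocol.

    Only the two methods MyProtocol relies on are declared; the real ones
    would need an open connection, so here they simply refuse to work.
    """

    def sendMessage(self, payload, binary=False):
        raise RuntimeError("not connected")

    def sendClose(self, *args, **kwargs):
        raise RuntimeError("not connected")


class MyProtocol(FakeWebSocketServerProtocol):
    def onMessage(self, msg, binary):
        # Application logic under test (hypothetical): close on "BYE", else echo
        if msg == "BYE":
            self.sendClose()
        else:
            self.sendMessage(msg)


class MyProtocolTest(unittest.TestCase):
    def setUp(self):
        self.protocol = MyProtocol()
        # Replace the framework's send methods with recorders on this instance,
        # so no transport and no connection state are needed
        self.protocol.sendMessage = mock.Mock()
        self.protocol.sendClose = mock.Mock()

    def test_echo(self):
        self.protocol.onMessage("Whatever", False)
        self.protocol.sendMessage.assert_called_once_with("Whatever")
        self.protocol.sendClose.assert_not_called()

    def test_bye_closes(self):
        self.protocol.onMessage("BYE", False)
        self.protocol.sendClose.assert_called_once_with()
        self.protocol.sendMessage.assert_not_called()
```

    These are plain unittest.TestCase classes, so they run under python -m unittest as well as under trial. The trade-off is that nothing checks the bytes actually put on the wire; that part is what the network test exercises.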


    Testing your server WITH network

    The test provided shows how you can test server push, asserting that what you are getting is what you expect, and also using a hook to determine when the test should finish.

    The server protocol

    from twisted.trial.unittest import TestCase as TrialTest
    from autobahn.websocket import WebSocketServerProtocol, WebSocketServerFactory, WebSocketClientProtocol, WebSocketClientFactory, connectWS, listenWS
    from twisted.internet.defer import Deferred
    from twisted.internet import task 
    
    START="START"            
    
    class TestServerProtocol(WebSocketServerProtocol):
    
        def __init__(self):
            #The publisher task simulates an event that triggers a message push
            self.publisher = task.LoopingCall(self.send_stuff, "Hi there")
    
        def send_stuff(self, msg):
            #this method sends a message to the client
            self.sendMessage(msg)
    
        def _on_start(self):
            #here we trigger the task to execute every second
            self.publisher.start(1.0)
    
        def onMessage(self, message, binary):
            #According to this stupid protocol, the server starts sending stuff when the client sends a "START" message
            #You can plug other commands in here
            {
               START : self._on_start
               #Put other keys here
            }[message]()
    
        def onClose(self, wasClean, code, reason):
            # After closing the connection, stop the task from sending messages
            # (LoopingCall.stop() raises if the task was never started)
            if self.publisher.running:
                self.publisher.stop()
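    The dictionary lookup in onMessage raises KeyError for any message that is not a known command, and that exception would propagate out of onMessage. Below is a stdlib-only sketch of the same dispatch idea with a fallback handler. CommandDispatcher and on_unknown are hypothetical names; in TestServerProtocol you would build the table in __init__ and call dispatch(message) from onMessage.

```python
START = "START"


class CommandDispatcher(object):
    """Hypothetical helper: maps text commands to zero-argument handlers."""

    def __init__(self, handlers, on_unknown=None):
        self.handlers = dict(handlers)
        self.on_unknown = on_unknown

    def dispatch(self, message):
        handler = self.handlers.get(message)
        if handler is not None:
            return handler()
        # Unknown command: delegate to the fallback instead of raising KeyError
        if self.on_unknown is not None:
            return self.on_unknown(message)
        return None


calls = []
dispatcher = CommandDispatcher(
    {START: lambda: calls.append("started")},
    on_unknown=lambda msg: calls.append("ignored " + msg),
)
dispatcher.dispatch(START)
dispatcher.dispatch("STOP")
print(calls)  # ['started', 'ignored STOP']
```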
    

    The client protocol and factory

    The next class is the client protocol. It tells the server to start pushing messages and calls close_condition on each received message to decide whether it is time to close the connection. Once the connection is closed, it fires the assertion Deferred with the messages it received, so the test can check whether it succeeded.

    class TestClientProtocol(WebSocketClientProtocol):
        def __init__(self, assertion, close_condition, timeout, *args, **kwargs):
            self.assertion = assertion
            self.close_condition = close_condition
            self._received_msgs = [] 
            from twisted.internet import reactor
            #This is a way to set a timeout for your test 
            #in case you never meet the conditions dictated by close_condition
            self.damocle_sword = reactor.callLater(timeout, self.sendClose)
    
        def onOpen(self):
            #After the connection has been established, 
            #you can tell the server to send its stuff
            self.sendMessage(START)
    
        def onMessage(self, msg, binary):
            #Here you get the messages pushed from the server
            self._received_msgs.append(msg)
            #If it is time to close the connection
            if self.close_condition(msg):
                self.damocle_sword.cancel()
                self.sendClose()
    
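        def onClose(self, wasClean, code, reason):
            # Make sure the timeout cannot fire after the connection is gone
            if self.damocle_sword.active():
                self.damocle_sword.cancel()
            # Now it is the right time to check our test assertions
            self.assertion.callback(self._received_msgs)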
    
    class TestClientProtocolFactory(WebSocketClientFactory):
        def __init__(self, assertion, close_condition, timeout, **kwargs):
            WebSocketClientFactory.__init__(self, **kwargs)
            self.assertion = assertion
            self.close_condition = close_condition
            self.timeout = timeout
            #This parameter needs to be forced to None to not leave the reactor dirty
            self.openHandshakeTimeout = None
    
        def buildProtocol(self, addr):
            protocol = TestClientProtocol(self.assertion, self.close_condition, self.timeout)
            protocol.factory = self
            return protocol
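    The damocle_sword delayed call is what stops the test from hanging when close_condition is never satisfied, and it must be cancelled on the normal path, otherwise trial complains about a pending delayed call in a dirty reactor. That timing logic can be checked without a network or a real reactor by injecting a clock. Twisted provides twisted.internet.task.Clock for this purpose; to keep the sketch runnable without Twisted, it uses a hypothetical FakeClock with a simplified callLater(delay, func)/advance(amount) shape, and a hypothetical MessageCollector that mirrors the client protocol's logic.

```python
class FakeDelayedCall(object):
    """Hypothetical minimal delayed call: fires once unless cancelled."""

    def __init__(self, when, func):
        self.when, self.func = when, func
        self.cancelled = False
        self.called = False

    def active(self):
        return not (self.cancelled or self.called)

    def cancel(self):
        self.cancelled = True


class FakeClock(object):
    """Hypothetical stand-in for twisted.internet.task.Clock (simplified)."""

    def __init__(self):
        self.now = 0.0
        self.calls = []

    def callLater(self, delay, func):
        call = FakeDelayedCall(self.now + delay, func)
        self.calls.append(call)
        return call

    def advance(self, amount):
        # Move time forward and run every active call that is now due
        self.now += amount
        for call in list(self.calls):
            if call.active() and call.when <= self.now:
                call.called = True
                call.func()


class MessageCollector(object):
    """Mirrors TestClientProtocol: collect messages, close on a condition or timeout."""

    def __init__(self, clock, close_condition, timeout):
        self.close_condition = close_condition
        self.received = []
        self.closed = False
        self.timed_out = False
        self.timeout_call = clock.callLater(timeout, self._on_timeout)

    def _on_timeout(self):
        self.timed_out = True
        self.close()

    def on_message(self, msg):
        self.received.append(msg)
        if self.close_condition(msg) and self.timeout_call.active():
            self.timeout_call.cancel()
            self.close()

    def close(self):
        self.closed = True


# Normal path: the third message closes the connection before the timeout
clock = FakeClock()
counter = {"n": 0}

def after_three(msg):
    counter["n"] += 1
    return counter["n"] == 3

collector = MessageCollector(clock, after_three, timeout=5)
for msg in ["Hi there"] * 3:
    collector.on_message(msg)
clock.advance(10)
print(collector.closed, collector.timed_out)  # True False
```

    Advancing the clock past the timeout after a normal close shows that the cancelled call never runs; with a condition that is never met, the same advance triggers the timeout path instead.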
    

    The trial based test

    class WebSocketTest(TrialTest):
    
        def setUp(self):
            port = 8088
            factory = WebSocketServerFactory("ws://localhost:{}".format(port))
            factory.protocol = TestServerProtocol
            self.listening_port = listenWS(factory)
            self.factory, self.port = factory, port
    
        def tearDown(self):
            # Clean up, otherwise the reactor complains; returning the Deferred
            # makes trial wait until the port has actually stopped listening
            return self.listening_port.stopListening()
    
        def test_message_reception(self):
            # This is the test assertion: we check that exactly 3 messages were received
            def assertion(msgs):
                self.assertEqual(len(msgs), 3)
    
            # This class says when the connection with the server should be closed.
            # In this case the condition is for the client to receive 3 messages
            class CommunicationHandler(object):
                msg_count = 0
    
                def close_condition(self, msg):
                    self.msg_count += 1
                    return self.msg_count == 3
    
            d = Deferred()
            d.addCallback(assertion)
            #Here we create the client...
            client_factory = TestClientProtocolFactory(d, CommunicationHandler().close_condition, 5, url="ws://localhost:{}".format(self.port))
            #...and we connect it to the server
            connectWS(client_factory)
            # Returning the Deferred makes trial wait until the assertion has run
            return d
    

    This is obviously just an example, but as you can see, I did not have to mess around with makeConnection or any transport explicitly.