Tags: python, multithreading, sockets, websocket

How to properly frame WebSocket messages with Python's socket library


I'm building a WebSocket server in Python using the socket module. It runs in an embedded Python environment, so I can't install any packages. I've got the handshake working and can establish a connection, and data goes back and forth between my server and my client (a React app). However, some of the payloads are too large: batching them up worked, but it was too slow. So I compressed the data using zlib, and now the client reports "Invalid WebSocket frame". I have condensed the code as much as possible into a server that runs and demonstrates the problem.

To be honest, I started writing this myself but used AI for some parts, particularly the WebSocket handshake and the message framing. I don't fully understand those parts (hence the question), and AI can only get you so far.

Here is my server code. It runs on Python 3.9 to 3.11; I haven't tried other versions.

import socket
import struct
import base64
import hashlib
import zlib
import logging
import json
from threading import Thread

class WebSocketServer(Thread):
    def __init__(self):
        Thread.__init__(self)
        self.connection = None
        self.logger = logging.getLogger('WebSocketServer')

    def run(self):
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(('', 5558))
            self.logger.info("Server started, waiting for connections...")
            sock.listen(1)
            while True:
                connection, _ = sock.accept()
                if connection:
                    self.connection = connection
                    self.logger.info("Client connected")
                    Thread(target=self.handle_connection, args=[connection]).start()
        except Exception as e:
            self.logger.error(f'Run error: {e}')
        finally:
            if self.connection:
                self.connection.close()
                self.logger.info('Server socket closed')

    def handle_connection(self, connection):
        try:
            if self.perform_handshake(connection):
                while True:
                    msg = self.receive_message(connection)
                    if msg:
                        self.logger.info(f'Received message: {msg}')
                        # Echo messages back
                        self.send_message(json.dumps(msg))
                    else:
                        break
        except Exception as e:
            self.logger.error(f'Connection error: {e}')
        finally:
            connection.close()
            self.logger.info('Connection closed')

    def perform_handshake(self, connection):
        try:
            self.logger.info("Performing handshake...")
            request = connection.recv(1024).decode('utf-8')
            self.logger.info(f"Handshake request: {request}")

            headers = self.parse_headers(request)
            websocket_key = headers['Sec-WebSocket-Key']
            websocket_accept = self.generate_accept_key(websocket_key)

            response = (
                'HTTP/1.1 101 Switching Protocols\r\n'
                'Upgrade: websocket\r\n'
                'Connection: Upgrade\r\n'
                f'Sec-WebSocket-Accept: {websocket_accept}\r\n\r\n'
            )

            connection.send(response.encode('utf-8'))
            self.logger.info("Handshake response sent")
            return True
        except Exception as e:
            self.logger.error(f'Handshake error: {e}')
            return False

    def parse_headers(self, request):
        headers = {}
        lines = request.split('\r\n')
        for line in lines[1:]:
            if line:
                key, value = line.split(': ', 1)
                headers[key] = value
        return headers

    def generate_accept_key(self, websocket_key):
        magic_string = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
        accept_key = base64.b64encode(hashlib.sha1((websocket_key + magic_string).encode()).digest()).decode('utf-8')
        return accept_key

    def receive_message(self, connection):
        try:
            data = connection.recv(1024)
            if not data:
                return None

            byte1, byte2 = struct.unpack('BB', data[:2])
            fin = byte1 & 0b10000000
            opcode = byte1 & 0b00001111
            masked = byte2 & 0b10000000
            payload_length = byte2 & 0b01111111

            if masked != 0b10000000:
                self.logger.error('Client data must be masked')
                return None

            if payload_length == 126:
                extended_payload_length = data[2:4]
                payload_length = int.from_bytes(extended_payload_length, byteorder='big')
                masking_key = data[4:8]
                payload_data = data[8:]
            elif payload_length == 127:
                extended_payload_length = data[2:10]
                payload_length = int.from_bytes(extended_payload_length, byteorder='big')
                masking_key = data[10:14]
                payload_data = data[14:]
            else:
                masking_key = data[2:6]
                payload_data = data[6:]

            decoded_bytes = bytearray()
            for i in range(payload_length):
                decoded_bytes.append(payload_data[i] ^ masking_key[i % 4])

            if opcode == 0x1:  # Text frame
                return decoded_bytes.decode('utf-8')
            elif opcode == 0x8:  # Connection close frame
                self.logger.info('Connection closed by client')
                return None
            else:
                self.logger.error(f'Unsupported frame type: {opcode}')
                return None
        except Exception as e:
            self.logger.error(f'Error receiving message: {e}')
            return None

    def send_message(self, message):
        try:
            if self.connection and isinstance(message, str):
                # Compress the message using zlib
                compressed_message = zlib.compress(message.encode('utf-8'))

                # Determine chunk size based on network conditions
                max_chunk_size = 1024  # Adjust as needed

                # Split the compressed message into smaller chunks
                chunks = [compressed_message[i:i+max_chunk_size] for i in range(0, len(compressed_message), max_chunk_size)]

                for chunk in chunks:
                    frame = bytearray()
                    frame.append(0b10000001)  # Text frame opcode

                    length = len(chunk)
                    if length <= 125:
                        frame.append(length)
                    elif length <= 65535:
                        frame.append(126)
                        frame.extend(struct.pack('!H', length))
                    else:
                        frame.append(127)
                        frame.extend(struct.pack('!Q', length))

                    # Append the chunk to the frame
                    frame.extend(chunk)

                    # Send the framed chunk
                    self.connection.sendall(frame)
            else:
                self.logger.error("Connection closed or invalid message")
        except Exception as e:
            self.logger.error(f'Error sending message: {e}')
            
            
# Configure logging
logging.basicConfig(level=logging.INFO)

# Create an instance of WebSocketServer
server = WebSocketServer()

# Start the server
server.start()
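A related problem in receive_message, separate from the UTF-8 error: it does a single recv(1024) and assumes the whole frame is in that buffer. TCP is a byte stream, so recv() can return part of a frame (or pieces of two frames), and for any client message longer than the bytes actually received, the decode loop hits an IndexError, which is logged and ends the connection. A minimal sketch of reading one client frame by exact byte counts follows. recv_exact and read_client_frame are hypothetical helper names, and the sketch does not reassemble fragmented messages; it just returns the FIN flag and opcode so the caller can decide.

```python
import socket
import struct


def recv_exact(conn: socket.socket, n: int) -> bytes:
    """Read exactly n bytes, looping because recv() may return fewer."""
    buf = bytearray()
    while len(buf) < n:
        chunk = conn.recv(n - len(buf))
        if not chunk:
            raise ConnectionError("socket closed mid-frame")
        buf.extend(chunk)
    return bytes(buf)


def read_client_frame(conn: socket.socket):
    """Read one masked client frame and return (fin, opcode, unmasked payload)."""
    byte1, byte2 = recv_exact(conn, 2)
    fin = bool(byte1 & 0x80)
    opcode = byte1 & 0x0F
    if not byte2 & 0x80:
        raise ValueError("client frames must be masked")
    length = byte2 & 0x7F
    if length == 126:
        (length,) = struct.unpack("!H", recv_exact(conn, 2))  # 16-bit length
    elif length == 127:
        (length,) = struct.unpack("!Q", recv_exact(conn, 8))  # 64-bit length
    mask = recv_exact(conn, 4)
    payload = recv_exact(conn, length)
    # Unmask: XOR each payload byte with the masking key, cycling every 4 bytes
    unmasked = bytes(byte ^ mask[i % 4] for i, byte in enumerate(payload))
    return fin, opcode, unmasked
```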

To test it, I'm using wscat from the terminal:

wscat -c ws://127.0.0.1:5558

Then I type any message, and the response I get is:

error: Invalid WebSocket frame: invalid UTF-8 sequence
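The error is raised by the client, not by Python: RFC 6455 requires the payload of a text frame (opcode 0x1) to be valid UTF-8, and zlib output is arbitrary binary data. A quick check, using a made-up JSON string shaped like the SysEx payload, shows that the compressed bytes already fail UTF-8 decoding at the zlib header:

```python
import zlib

# Made-up JSON text shaped like the SysEx payload
message = "[[240, 67, 16, 247], [240, 67, 16, 247]]"
compressed = zlib.compress(message.encode("utf-8"))

try:
    compressed.decode("utf-8")
    is_valid_utf8 = True
except UnicodeDecodeError:
    # zlib's default header is 0x78 0x9C; 0x9C is a UTF-8 continuation byte,
    # which cannot start a character, so decoding fails immediately
    is_valid_utf8 = False

print(compressed[:2].hex(), is_valid_utf8)  # prints: 789c False
```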

  1. Why am I seeing this error?
  2. How should I be framing these messages?
  3. Is there a more efficient way of doing this?
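On the framing question, here is a sketch under these assumptions: the server sends unmasked frames (RFC 6455 requires server-to-client frames to be unmasked), the compressed payload goes in a binary frame (opcode 0x2), and each message is normally sent as one frame. In the original send_message, every 1024-byte chunk is sent with the FIN bit set, so the client treats each chunk as a separate message and cannot decompress it on its own. If chunking is really needed, the first frame carries the real opcode, later frames use the continuation opcode 0x0, and only the last one has FIN set. encode_frame and encode_fragmented are hypothetical names.

```python
import struct
import zlib

OP_CONTINUATION = 0x0
OP_TEXT = 0x1
OP_BINARY = 0x2


def encode_frame(payload: bytes, opcode: int = OP_BINARY, fin: bool = True) -> bytes:
    """Build one unmasked server-to-client frame (RFC 6455, section 5.2)."""
    header = bytearray()
    header.append((0x80 if fin else 0x00) | opcode)  # FIN bit + 4-bit opcode
    length = len(payload)
    if length <= 125:
        header.append(length)  # 7-bit length, MASK bit clear
    elif length <= 0xFFFF:
        header.append(126)
        header.extend(struct.pack("!H", length))  # 16-bit extended length
    else:
        header.append(127)
        header.extend(struct.pack("!Q", length))  # 64-bit extended length
    return bytes(header) + payload


def encode_fragmented(payload: bytes, chunk_size: int, opcode: int = OP_BINARY) -> list:
    """Split one message into fragments: the first frame has the real opcode,
    the rest use opcode 0x0 (continuation), and only the last sets FIN."""
    chunks = [payload[i:i + chunk_size] for i in range(0, len(payload), chunk_size)] or [b""]
    frames = []
    for index, chunk in enumerate(chunks):
        frames.append(encode_frame(
            chunk,
            opcode=opcode if index == 0 else OP_CONTINUATION,
            fin=index == len(chunks) - 1,
        ))
    return frames


compressed = zlib.compress(b"[[240, 67, 16, 247]]" * 200)
frame = encode_frame(compressed)  # one complete binary message
```

In send_message this would amount to something like self.connection.sendall(encode_frame(compressed_message)). Sending one frame is the simpler choice, and splitting is unlikely to speed anything up since sendall already hands the whole buffer to the socket; RFC 6455 describes fragmentation mainly as a way to start sending a message whose total size isn't known yet.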

For context, the payloads I'm sending are arrays of arrays of integers. The data is MIDI SysEx, so each array starts with a 0xF0 byte and ends with a 0xF7 byte. These messages are sent in quick succession, which is what was causing the problem. The first message is usually large and gets batched up, so it could contain thousands of SysEx arrays.
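On question 3, one option (an alternative, not something the code above already does) is to drop JSON and send raw bytes. In MIDI, the data bytes between 0xF0 and 0xF7 are 7-bit values, so 0xF7 only appears as the terminator and a concatenated stream can be split on it again. A sketch with hypothetical helper names and made-up example bytes, assuming every message is well-formed:

```python
import json
import zlib


def pack_sysex(messages) -> bytes:
    """Concatenate SysEx messages (lists of ints, each F0 ... F7) into raw bytes."""
    # bytes() raises ValueError if any value is outside 0-255
    return b"".join(bytes(msg) for msg in messages)


def unpack_sysex(data: bytes) -> list:
    """Split concatenated SysEx bytes back into lists, ending each message at 0xF7."""
    messages, start = [], 0
    for index, value in enumerate(data):
        if value == 0xF7:
            messages.append(list(data[start:index + 1]))
            start = index + 1
    return messages


# Made-up example: 1000 identical 9-byte SysEx messages
batch = [[0xF0, 0x43, 0x10, 0x4C, 0x00, 0x00, 0x7E, 0x00, 0xF7]] * 1000
raw = pack_sysex(batch)                        # 9000 bytes
as_json = json.dumps(batch).encode("utf-8")    # the same data as JSON text
payload = zlib.compress(raw)                   # send this with opcode 0x2
```

Each SysEx byte then costs one byte instead of up to five characters of JSON (such as "240, ") before compression even starts. On the client, the decompressed bytes can be viewed as a Uint8Array and split on 0xF7 the same way.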

I have sent this kind of data over similar WebSocket connections before, so I know it is possible, but I've never had to write the server from scratch.

Any help here would be greatly appreciated.


Solution

  • I have found the solution to my problem, so I'll answer it here.

    Thanks to @SteffenUllrich and @MarkTolonen for their guidance above.

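    As @SteffenUllrich pointed out, my client was expecting text data but I was sending compressed binary data, so I needed to use the binary frame opcode (0x2) instead of the text opcode (0x1). The payload of a text frame must be valid UTF-8, which zlib output is not, hence the "invalid UTF-8 sequence" error.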

    In the send_message function, I changed this line to:

    frame.append(0b10000010)  # Binary frame opcode
    

    Instead of

    frame.append(0b10000001)  # Text frame opcode
    

    With that change the data comes through in wscat, but I still wasn't seeing it in the browser. So I needed to set the following in my React app, so that binary messages are delivered as an ArrayBuffer instead of the default Blob:

    websocket.binaryType = 'arraybuffer';
    

    Additional reading that helped me understand what's going on can be found here
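With binaryType set to 'arraybuffer', each message reaches the React app as an ArrayBuffer that still holds zlib data, so the client has to inflate it before parsing the JSON. To check the server's output without a browser, the same steps can be sketched in Python: parse one unmasked server frame, check for the binary opcode, then zlib.decompress and json.loads. decode_server_frame is a hypothetical name, and the sketch assumes the whole frame is already in memory.

```python
import json
import struct
import zlib


def decode_server_frame(data: bytes):
    """Parse one unmasked server frame; return (fin, opcode, payload, bytes consumed)."""
    fin = bool(data[0] & 0x80)
    opcode = data[0] & 0x0F
    if data[1] & 0x80:
        raise ValueError("server-to-client frames must not be masked")
    length = data[1] & 0x7F
    offset = 2
    if length == 126:
        (length,) = struct.unpack_from("!H", data, offset)
        offset += 2
    elif length == 127:
        (length,) = struct.unpack_from("!Q", data, offset)
        offset += 8
    if len(data) < offset + length:
        raise ValueError("incomplete frame")
    return fin, opcode, data[offset:offset + length], offset + length


# A short binary frame shaped like the fixed server's output: FIN + opcode 0x2,
# 7-bit length (the compressed body here is well under 126 bytes)
body = zlib.compress(json.dumps([[240, 67, 16, 247]]).encode("utf-8"))
frame = bytes([0x82, len(body)]) + body

fin, opcode, payload, used = decode_server_frame(frame)
if opcode == 0x2:
    messages = json.loads(zlib.decompress(payload))
```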