
Implementing an extended Hamming code encoder

I'm trying to implement an extended Hamming code encoder after watching 3Blue1Brown's excellent videos on the subject, but I can't figure out what I'm doing wrong. I have the following code:

import math

def get_bitstring(n, bit_length) -> str:
    return format(n, f'0{bit_length}b')

def get_bit(n: int, bit_index: int) -> int:
    return (n >> bit_index) & 1

def set_bit(n: int, bit_index: int, value: int) -> int:
    mask = 1 << bit_index
    n &= ~mask
    if value:
        n |= mask
    return n

def get_on_bits(bits, bit_length):
    # Bitstrings are MSB first, so list index i is bit position bit_length - 1 - i (global parity is position 0)
    return [bit_length - i - 1 for i, bit in enumerate(bits) if bit]

class HammingCode:
    def __init__(self, bits_per_block: int) -> None:
        self.raw_message: int = 0
        self.encoded_message: int = 0
        self.bits_per_block: int = bits_per_block
        self.parity_bits_per_block: int = int(math.log2(self.bits_per_block)) + 1
        self.data_bits_per_block: int = self.bits_per_block - self.parity_bits_per_block

    def encode(self, message: bytes | bytearray) -> int:
        self.raw_message = int.from_bytes(message, 'big')
        # Place the data bits, then fill in the parity bits
        self.copy_bits()
        return self.compute_parity()

    def copy_bits(self) -> int:
        self.encoded_message = 0
        bits_used = 0
        # Copies bits into encoded message
        for i in range(self.bits_per_block):
            if not self.is_parity_bit(i):
                bit_value = get_bit(self.raw_message, bits_used)
                self.encoded_message = set_bit(self.encoded_message, i, bit_value)
                bits_used += 1
        assert(bits_used == self.data_bits_per_block)
        return self.encoded_message

    def compute_parity(self):
        from functools import reduce
        bits = [int(bit) for bit in self.get_bitstring()]
        parity = reduce(lambda x, y: x ^ y, get_on_bits(bits, self.bits_per_block))
        for i in range(1, self.parity_bits_per_block):
            parity_index = 1 << i
            if parity & parity_index != 0:
                bit_value = 0 if get_bit(self.encoded_message, parity_index) else 1
                self.encoded_message = set_bit(self.encoded_message, i + 1, bit_value)
        bits = [int(bit) for bit in self.get_bitstring()]
        parity = len(get_on_bits(bits, self.bits_per_block))
        self.encoded_message = set_bit(self.encoded_message, 0, 1 if parity & 1 == 1 else 0)
        return self.encoded_message

    def is_parity_bit(self, bit_index: int) -> bool:
        assert(bit_index < self.bits_per_block)
        return (bit_index & (bit_index - 1)) == 0

    def get_bitstring(self) -> str:
        return get_bitstring(self.encoded_message, self.bits_per_block)

data = b'\x03\x8c'
value = int.from_bytes(data, 'big')
print(f"Running with {hex(value)} {get_bitstring(value, 16)}")
h = HammingCode(16)
n = h.encode(data)
print(f"Encoded message: {get_bitstring(value, h.data_bits_per_block)} => {get_bitstring(n, h.bits_per_block)}")
print(f"Encoded block in hex is {hex(n)}")

Any tips or ideas on what I'm doing wrong are much appreciated!

I get the correct output for some values, but not others. For example, I get the expected 0xf if I give 0x1 as input, and I get 0x801f with 0x400 as input. If I give 0x38c as input, I should get 0x71d4 (I've found this by doing it on paper), but I instead get 0x70ce.
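One way to check hand-computed results is to test the two properties the encoder is supposed to establish: the XOR of the positions of all set bits is 0, and the block contains an even number of 1s. The following is a minimal sketch assuming the question's 16-bit layout (global parity at position 0, parity bits at 1, 2, 4, 8); `syndrome` and `is_valid_codeword` are hypothetical helper names, not part of the question's code.

```python
from functools import reduce
from operator import xor


def syndrome(block: int, block_bits: int = 16) -> int:
    """XOR together the positions of all set bits in the block."""
    on_positions = [pos for pos in range(block_bits) if (block >> pos) & 1]
    # 0 as the initializer keeps reduce from failing on an all-zero block
    return reduce(xor, on_positions, 0)


def is_valid_codeword(block: int, block_bits: int = 16) -> bool:
    """A valid extended Hamming codeword has syndrome 0 and an even number of 1s."""
    return syndrome(block, block_bits) == 0 and bin(block).count("1") % 2 == 0


for candidate in (0xF, 0x71D4, 0x70CE):
    print(hex(candidate), is_valid_codeword(candidate))
# 0xf True
# 0x71d4 True
# 0x70ce False
```

The same check also rejects 0x801f (syndrome 11); it has bit 3, a data position, set even though 0x400 only has its highest data bit set. Under this layout, 0x400 maps to 0x8117.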

I've tested and retested each function on its own, including copy_bits (which inserts a 0 at index 0 and at every index that is a power of 2) and the helper functions above the HammingCode class, and they all return what I expect. With the 0x38c example, copy_bits returns 0b0111000011000000, which is what I want: the 11 data bits 0b01110001100 with the 5 parity positions inserted as 0s.
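The placement that copy_bits performs can also be written as a pure function, which makes it easy to compare against the hand-worked value. This is a sketch assuming the same 16-bit block and LSB-first data order; `spread_data_bits` and `is_parity_position` are hypothetical names.

```python
def is_parity_position(pos: int) -> bool:
    # Position 0 (global parity) and powers of two (1, 2, 4, 8, ...) are reserved for parity
    return (pos & (pos - 1)) == 0


def spread_data_bits(data: int, block_bits: int = 16) -> int:
    """Place data bits, LSB first, into the non-parity positions of the block."""
    block = 0
    data_index = 0
    for pos in range(block_bits):
        if not is_parity_position(pos):
            # Copy data bit data_index into block position pos
            block |= ((data >> data_index) & 1) << pos
            data_index += 1
    return block


print(format(spread_data_bits(0x38C), "016b"))  # 0111000011000000
```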

I believe the problem is in my compute_parity method, but I can't for the life of me figure it out. The logic seems sound: after XOR-ing the indices of the set bits, I want to set the parity bits so that calling reduce(...) again would return 0. For each bit i that is set in parity, I want to flip the parity bit at index 2**i in the encoded message. Finally, I XOR all the bits together and set bit 0 so the whole block has even parity.
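The plan described above (XOR the positions of the set bits, set parity bit 2**i for every set bit i of the result, then fix the global parity) can be sketched on plain integers without any bitstring round trip. This assumes the data bits have already been placed and all parity positions are still 0; `set_parity_bits` is a hypothetical name.

```python
from functools import reduce
from operator import xor


def set_parity_bits(block: int, block_bits: int = 16) -> int:
    """Fill in the Hamming parity bits and the global parity bit of a block
    whose parity positions (0, 1, 2, 4, 8, ...) are still 0."""
    # XOR of the positions of every set bit; 0 as the initializer handles an empty block
    syndrome = reduce(xor, (pos for pos in range(block_bits) if (block >> pos) & 1), 0)
    # Parity positions 1, 2, 4, ... below block_bits (assumes block_bits is a power of two)
    for i in range(block_bits.bit_length() - 1):
        parity_pos = 1 << i
        if (syndrome >> i) & 1:
            # Bit i of the syndrome controls the parity bit at position 2**i
            block ^= 1 << parity_pos
    # Global parity at position 0 makes the total number of 1s even
    if bin(block).count("1") % 2:
        block |= 1
    return block


print(hex(set_parity_bits(0x70C0)))  # 0x71d4
```

Note that the write goes to position 1 << i, not to position i or i + 1; after the loop the syndrome of the block is 0.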


  • The problem was indeed in compute_parity. My loop started at i = 1, so the parity bit at index 1 was never set, and it wrote the flipped value to index i + 1 instead of to parity_index (1 << i), which overwrote data bits. The fix is to take bit i of the value returned by reduce, for every i starting at 0, and write it to index 1 << i. Here is the corrected code:

    def compute_parity(self) -> int:
        from functools import reduce
        if self.encoded_message == 0:
            return 0
        bits = [int(bit) for bit in self.get_bitstring()]
        parity = reduce(lambda x, y: x ^ y, get_on_bits(bits, self.bits_per_block))
        parity_bits = list(get_bitstring(parity, self.parity_bits_per_block - 1))
        parity_bits = parity_bits[::-1] # Needs to be reversed to have the LSB at index 0
        for i, parity_bit in enumerate(parity_bits):
            parity_index = 1 << i
            self.encoded_message = set_bit(self.encoded_message, parity_index, int(parity_bit))
        bits = [int(bit) for bit in self.get_bitstring()]
        parity = len(get_on_bits(bits, self.bits_per_block))
        self.encoded_message = set_bit(self.encoded_message, 0, 1 if parity & 1 else 0)
        assert(reduce(lambda x, y: x ^ y, get_on_bits(bits, self.bits_per_block)) == 0)
        return self.encoded_message

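    I also had to convert each character of the bitstring to an int before passing it to set_bit: set_bit tests "if value:", and the string '0' is truthy, so passing it directly would always set the bit.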
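The string issue comes down to truthiness: set_bit tests "if value:", and any non-empty string, including '0', is truthy. A small sketch illustrating this with the question's set_bit helper; the parity value below is made up for illustration, and reading bits with shifts is shown as one alternative to building and reversing a bitstring.

```python
def set_bit(n: int, bit_index: int, value: int) -> int:
    # Same helper as in the question: clear the bit, then set it if value is truthy
    mask = 1 << bit_index
    n &= ~mask
    if value:
        n |= mask
    return n


parity = 0b0110  # illustrative value for the XOR of the set-bit positions

# The string '0' is non-empty and therefore truthy, so the bit gets set anyway
print(set_bit(0, 1, "0"))       # 2
print(set_bit(0, 1, int("0")))  # 0

# Reading bit i of parity directly gives ints, LSB first, with no reversal needed
bits_from_ints = [(parity >> i) & 1 for i in range(4)]
print(bits_from_ints)  # [0, 1, 1, 0]
```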