Problem with bit swapping in Haskell


As part of a school project I'm implementing some cryptographic algorithms in Haskell. As you probably know, this involves quite a lot of low-level bit fiddling. I am now stuck on one particular subroutine that is giving me a headache. The routine, which is a permutation on 256 bits, works as follows:

Input: a 256 bit block.
All the even-numbered bits (0, 2, ...) of the input block become the first 128 bits of the output block, and the odd-numbered bits become the last 128 bits. More specifically, the i'th bit of the output is given by the following formula, where $a_i$ is the i'th bit of the input block and $b$ is the output:

$$b_i = a_{2i}$$

$$b_{i + 2^{d-1}} = a_{2i + 1}$$

for $i$ from $0$ to $2^{d-1} - 1$, with $d = 8$.

As a toy example, assume a reduced version of the routine that works on 16-bit blocks instead of 256-bit blocks. The following bit string would then be permuted as follows:

1010 1010 1010 1010 -> 1111 1111 0000 0000
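
A minimal sketch of the formula above on a plain list, assuming bit 0 is the leftmost element and the block has even length. The names `permute`, `toyInput`, and `toyOutput` are illustrative and not part of the question, and the `!!` indexing is chosen to mirror the formula rather than for speed:

```haskell
-- Direct transcription of b_i = a_{2i} and b_{i + n/2} = a_{2i+1}
-- for a block given as a list of n elements (n is assumed to be even).
permute :: [a] -> [a]
permute a =
  [a !! (2 * i) | i <- [0 .. half - 1]]
    ++ [a !! (2 * i + 1) | i <- [0 .. half - 1]]
  where
    half = length a `div` 2

-- The 16-bit toy block, written as '0'/'1' characters.
-- Assumption: the leftmost character is bit 0.
toyInput :: String
toyInput = "1010101010101010"

toyOutput :: String
toyOutput = permute toyInput
```

In GHCi, `toyOutput` evaluates to `"1111111100000000"`, which matches the toy example. The same function applies to the full 256-element block, since only the list length changes.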

I have not been able to come up with a clean implementation of this function. In particular, I have been trying a ByteString -> ByteString signature, but that forces me to work at Word8 granularity. Each byte in the output ByteString is a function of bits spread across multiple input bytes, which requires some really messy operations.

I will be really grateful for any kind of hint or advice on how to approach this problem.


Solution

  • If you want an efficient implementation, I don't think you can avoid working with bytes. Here is an example solution. It assumes that there is always an even number of bytes in the ByteString. I'm not very familiar with unboxing or strictness tweaking, but I think these would be necessary if you want to be very efficient.

    import Data.ByteString (pack, unpack, ByteString)
    import Data.Bits
    import Data.Word
    
    -- the main attraction
    packString :: ByteString -> ByteString
    packString = pack . packWords . unpack
    
    -- main attraction equivalent, in [Word8]
    packWords :: [Word8] -> [Word8]
    packWords ws = evenPacked ++ unevenPacked
        where evenBits = map packEven ws
              unevenBits = map packUneven ws
              evenPacked = consumePairs packNibbles evenBits
              unevenPacked = consumePairs packNibbles unevenBits
    
    -- combines 2 low nibbles (the lowest 4 bits of each argument) into a (high nibble, low nibble) word
    -- assumes that only the low nibble of both arguments can be non-zero.
    packNibbles :: Word8 -> Word8 -> Word8
    packNibbles w1 w2 = (shiftL w1 4) .|. w2 
    
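    -- gathers bits 0, 2, 4, 6 (counted from the least significant bit) into the low nibble
    packEven :: Word8 -> Word8
    packEven w = packBits w [0, 2, 4, 6]
    
    -- gathers bits 1, 3, 5, 7 (counted from the least significant bit) into the low nibble
    packUneven :: Word8 -> Word8
    packUneven w = packBits w [1, 3, 5, 7]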
    
    -- packBits 254 [0, 2, 4, 6] = 14 
    -- packBits 254 [1, 3, 5, 7] = 15
    packBits :: Word8 -> [Int] -> Word8
    packBits w is = foldr (.|.) 0 $ map (packBit w) is
    
    -- packBit 255 0 = 1
    -- packBit 255 1 = 1
    -- packBit 255 2 = 2
    -- packBit 255 3 = 2
    -- packBit 255 4 = 4
    -- packBit 255 5 = 4
    -- packBit 255 6 = 8
    -- packBit 255 7 = 8
    packBit :: Word8 -> Int -> Word8
    packBit w i = shiftR (w .&. 2^i) ((i `div` 2) + (i `mod` 2))
    
    -- sort of like map, but halves the list in size by consuming two elements at a time.
    -- Is there a clearer way to write this with a built-in function?
    consumePairs :: (a -> a -> b) -> [a] -> [b]
    consumePairs f (x : x' : xs) = f x x' : consumePairs f xs
    consumePairs _ [] = []
    consumePairs _ _ = error "list must contain even number of elements"
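
The code above numbers the bits of each byte from the least significant end. The question does not say which numbering the cipher uses, so here is a hedged sketch of a variant that assumes the opposite convention: bit 0 is the most significant bit of the first byte. It trades speed for clarity by going through a list of `Bool`. All names (`byteToBits`, `bitsToByte`, `chunksOf8`, `deinterleave`, `permuteMsbFirst`) are illustrative and not from the answer:

```haskell
import qualified Data.ByteString as B
import Data.Bits (testBit)
import Data.List (foldl')
import Data.Word (Word8)

-- Expand a byte into its 8 bits, most significant bit first (assumed convention).
byteToBits :: Word8 -> [Bool]
byteToBits w = [testBit w i | i <- [7, 6 .. 0]]

-- Rebuild a byte from up to 8 bits, most significant bit first.
bitsToByte :: [Bool] -> Word8
bitsToByte = foldl' (\acc b -> acc * 2 + (if b then 1 else 0)) 0

-- Split a list into consecutive groups of 8 elements.
chunksOf8 :: [a] -> [[a]]
chunksOf8 [] = []
chunksOf8 xs = let (h, t) = splitAt 8 xs in h : chunksOf8 t

-- Even-indexed elements first, then odd-indexed elements.
deinterleave :: [a] -> [a]
deinterleave xs = evens ++ odds
  where
    (evens, odds) = foldr (\x (es, os) -> (x : os, es)) ([], []) xs

-- The permutation on a whole ByteString, under the MSB-first assumption.
permuteMsbFirst :: B.ByteString -> B.ByteString
permuteMsbFirst =
  B.pack . map bitsToByte . chunksOf8 . deinterleave . concatMap byteToBits . B.unpack
```

For example, `permuteMsbFirst (B.pack [0xAA, 0xAA])` gives `B.pack [0xFF, 0x00]`, the toy example read left to right. Because the total bit count is always a multiple of 8, this version does not need the even-number-of-bytes assumption.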