
What's wrong with this Haskell AES key expansion?


Following the key expansion steps in Jeff Moser's popular tutorial, I've written the code below. Here is the entire file (which also computes the S-box), so folks can compile and try it.

{-# LANGUAGE NoMonomorphismRestriction #-}

import Control.Applicative (liftA2)
import Data.Bits (xor, shiftL, shiftR, (.|.), (.&.))
import Data.List (transpose, sortBy)
import Data.Ord (comparing)
import Data.Word (Word8)
import Numeric (showHex)

keys = f 16 $ f 8 $ f 4 $ f 2 $ f 1 key
 where
  f w n = xpndC . xpndB . xpndA $ xpndD w n

xpndC   :: [[Word8]] -> [[Word8]]
xpndC ws = transpose [head ws, b, zipWith xor b c, last ws]
 where
  (b,c) = (ws !! 1, ws !! 2)

xpndB   :: [[Word8]] -> [[Word8]]
xpndB ws = a : zipWith xor a b : drop 2 ws
 where
  (a,b) = (head ws, ws !! 1)

xpndA   :: [[Word8]] -> [[Word8]]
xpndA ws = zipWith xor a d : tail ws
 where
  (a,d) = (head ws, last ws)

xpndD rc ws = take 3 tW ++ [w']
 where
  w' = zipWith xor (map sub w) [rc, 0, 0, 0]
  tW = transpose ws
  w  = take 4 $ tail $ cycle $ last tW

--------------------------------------------------------------
sub w = get sbox (fromIntegral lo) $ fromIntegral hi
 where
  (hi, lo) = nibs w

get wss x y = (wss !! y) !! x

print' = print . w128 . concat . transpose
 where
  w128 = concatMap (f . (`showHex` ""))
  f w  = (length w < 2) ? (' ':'0':w, ' ':w)

grid _ [] = []
grid n xs = take n xs : grid n (drop n xs)

nibs w    = (shiftR (w .&. 0xF0) 4, w .&. 0x0F)
(⊕)       = xor
p ? (a,b) = if p then a else b; infix 2 ?

---------------------------------------------------
sbox :: [[Word8]]
sbox = grid 16 $ map snd $ sortBy (comparing fst) $ sbx 1 1 []

sbx :: Word8 -> Word8 -> [(Word8, Word8)] -> [(Word8, Word8)]
sbx p q ws
  | length ws == 255 = (0, 0x63) : ws
  | otherwise = sbx p' r $ (p', xf ⊕ 0x63) : ws
 where
  p' = p  ⊕  shiftL p  1 ⊕  ((p .&. 0x80 /= 0) ? (0x1B, 0))
  q1 = foldl (liftA2 (.) xor shiftL) q [1, 2, 4]
  r  = q1 ⊕  ((q1 .&. 0x80 /= 0) ? (0x09, 0))
  xf = r  ⊕  rotl8 r 1 ⊕  rotl8 r 2 ⊕  rotl8 r 3 ⊕  rotl8 r 4

rotl8 w n = (w `shiftL` n) .|. (w `shiftR` (8 - n))

key = [[0,0,0,0],
       [0,0,0,0],
       [0,0,0,0],
       [0,0,0,0]] :: [[Word8]]

When I test this code against the all-zero test key, it matches the published expectation up to the fourth iteration: ee 06 da 7b 87 6a 15 81 75 9e 42 b2 7e 91 ee 2b.

But when I try the next iteration, keys = f 16 $ f 8 $ f 4 $ f 2 $ f 1 key, the last 32 bits of the result are wrong: 7f 2e 2b 88 f8 44 3e 09 8d da 7c bb 91 28 f1 f3.

The same behavior - last 32 bits wrong - happens when I use all 0xFF for the initial key. And in subsequent iterations, all the bits are wrong.

If I use the test vector 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f, things go wrong much faster - I start getting wrong bits on the second iteration.

What's going on here? I notice that in step 2b.4 Mr. Moser says to xor with the first column of the previous round key, but there is no previous round for the initial key, so this confused me. Is this what I've done wrong?

For reference, here are the test vectors.


Solution

  • You're missing a step.

    xpndC ws = transpose [head ws, b, zipWith xor b c, last ws]
    

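    The fourth column should be the xor of the previous fourth column (which you've thrown away in the first pass) and the new third column.

    Your version leaves the scrambled column (call it e) in the fourth position, whereas the correct value is a xor b xor c xor d xor e, where a, b, c, d are the columns of the previous round key. Since xor x x = 0, the two coincide whenever a xor b xor c xor d = 0. For the all-zero key this holds for the original key and the first three round keys, so the first four iterations come out right and the mistake only becomes noticeable at the fifth.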
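
To make that cancellation concrete, here is a minimal sketch that compares only the fourth column computed both ways. The names Column, correctFourth, questionFourth and columnSum are hypothetical helpers introduced for illustration, and e stands for an already scrambled column, so no S-box is needed.

```haskell
import Data.Bits (xor)
import Data.Word (Word8)

type Column = [Word8]

-- Fourth column when every column is chained (a' = a ^ e, b' = b ^ a', ...):
-- unrolling the chain gives d' = e ^ a ^ b ^ c ^ d.
correctFourth :: Column -> [Column] -> Column
correctFourth = foldl (zipWith xor)

-- Fourth column as the question's pipeline leaves it: the scrambled
-- column e is passed through unchanged.
questionFourth :: Column -> [Column] -> Column
questionFourth e _ = e

-- XOR of the old columns; the two versions agree exactly when this is all zeros.
columnSum :: [Column] -> Column
columnSum = foldl1 (zipWith xor)
```

With four identical old columns, columnSum is all zeros and both functions return e; as soon as the old columns stop cancelling, the results differ by exactly columnSum.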


    Minor stylistic comments

    Pattern matching on a fixed structure is less awkward than (!!).

    xpndC :: [[Word8]] -> [[Word8]]
    xpndC [a,b,c,d] = transpose [a, b, zipWith xor b c, d]
    

    Also note that steps 2b4 and 3 are actually a scan. Roughly, it ends up looking like this (with the name schedule_core borrowed from your last link):

    new = tail $ scanl (zipWith xor) (schedule_core (last old)) old
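
As a sanity check on the scan formulation, the following sketch puts the explicit chain next to the scanl version. The names chainExplicit and chainScan are hypothetical, and both take the already scrambled column e as an argument instead of calling schedule_core, so the snippet stands alone.

```haskell
import Data.Bits (xor)
import Data.Word (Word8)

-- Explicit chaining of the four columns, starting from the scrambled column e.
chainExplicit :: [Word8] -> [[Word8]] -> [[Word8]]
chainExplicit e [a, b, c, d] = [a', b', c', d']
  where
    a' = zipWith xor a e
    b' = zipWith xor b a'
    c' = zipWith xor c b'
    d' = zipWith xor d c'
chainExplicit _ _ = error "chainExplicit: expected exactly four columns"

-- The same computation as a scan: scanl yields e, e^a, e^a^b, ...;
-- drop 1 (tail, in the line above) discards the seed e and keeps the
-- four new columns.
chainScan :: [Word8] -> [[Word8]] -> [[Word8]]
chainScan e old = drop 1 (scanl (zipWith xor) e old)
```

Both functions should return the same four columns for any four-column input.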
    

    Edit: Fix

    The solution is essentially to not throw away the last column. You can, as a quick fix, inject it in an additional pass this way:

    keys = f 16 $ f 8 $ f 4 $ f 2 $ f 1 key
     where
      f w n = xpndE (transpose n) . xpndC . xpndB . xpndA $ xpndD w n
    
    xpndE n [a,b,c,_] = transpose [a,b,c,zipWith xor c (last n)]
    
    xpndC = (...) {- remove transpose here -}
    

    The xpnd* functions may be a bit too fine-grained, once you realize that the list is quite small. I would also factor transpose out, if you want to keep it at all.

    keys = transpose $ f 16 $ f 8 $ f 4 $ f 2 $ f 1 $ transpose key
      where
        f rc [a, b, c, d] =
          let e = schedule rc d
              a' = zipWith xor a e
              b' = zipWith xor b a'
              c' = zipWith xor c b'
              d' = zipWith xor d c'
          in [a', b', c', d']  -- Here is where one may recognize `scanl` or a fold.
    

    As for schedule, it's the function that takes the last column (d above, last tW below) and scrambles it (e above, w' below). You can extract it from your definition of xpndD:

    xpndD rc ws = take 3 tW ++ [w']
     where
      w' = zipWith xor (map sub w) [rc, 0, 0, 0]
      tW = transpose ws
      w  = take 4 $ tail $ cycle $ last tW
    

    We get (modulo a purely cosmetic rewriting take 4 $ tail $ cycle d = tail d ++ [head d]):

    schedule rc d = zipWith xor (map sub $ tail d ++ [head d]) [rc, 0, 0, 0]
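
Putting the pieces together, here is a minimal self-contained sketch of the corrected expansion. It makes several assumptions that are not in the code above: the key is held as a list of four columns (so no transpose is needed), sub is recomputed from a brute-force GF(2^8) inverse instead of the sbox table from the question, and the names xtime, gmul, nextKey and roundKeys are hypothetical.

```haskell
import Data.Bits (rotateL, shiftL, testBit, xor)
import Data.Word (Word8)

-- Multiply by x in GF(2^8), reducing by the AES polynomial (0x11B).
xtime :: Word8 -> Word8
xtime b = (b `shiftL` 1) `xor` (if testBit b 7 then 0x1B else 0)

-- Shift-and-add multiplication in GF(2^8).
gmul :: Word8 -> Word8 -> Word8
gmul a b = foldr xor 0 [p | (i, p) <- zip [0 .. 7] (iterate xtime a), testBit b i]

-- S-box value computed on the fly: multiplicative inverse (0 maps to 0),
-- then the affine map inv ^ rotl 1 ^ rotl 2 ^ rotl 3 ^ rotl 4 ^ 0x63.
sub :: Word8 -> Word8
sub w = foldr xor 0x63 [inv `rotateL` n | n <- [0 .. 4]]
  where
    inv = case [y | y <- [1 .. 255], gmul w y == 1] of
      (y : _) -> y
      []      -> 0

-- Rotate the last column by one byte, substitute, and add the round constant.
schedule :: Word8 -> [Word8] -> [Word8]
schedule rc d = zipWith xor (map sub (drop 1 d ++ take 1 d)) [rc, 0, 0, 0]

-- One round of expansion as a scan over the old columns; the old fourth
-- column is kept and chained, which is the step the question's code drops.
nextKey :: [[Word8]] -> Word8 -> [[Word8]]
nextKey old rc = drop 1 (scanl (zipWith xor) (schedule rc (last old)) old)

-- The original key followed by the first five round keys.
roundKeys :: [[Word8]] -> [[[Word8]]]
roundKeys key = scanl nextKey key [1, 2, 4, 8, 16]
```

For the all-zero key, roundKeys (replicate 4 [0,0,0,0]) !! 4 should reproduce the fourth-iteration value quoted in the question, and the fifth round key should end in f3 4b 92 90 rather than 91 28 f1 f3.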