Using the type system to check length of output vs. input list

Suppose a list L, with length n, is interleaved in list J, with length n + 1. We'd like to know, for each element of J, which of its neighbors from L is the greater. The following function takes L as its input, and produces a list K, also of length n + 1, such that the ith element of K is the desired neighbor of the ith element of J.

aux [] prev acc = prev:acc
aux (hd:tl) prev acc = aux tl hd ((max hd prev):acc)

expand row = reverse (aux row 0 [])
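
To make the intended behaviour concrete, here is the same plain-list function with explicit type signatures, as a minimal sketch. The helper `lengthInvariantHolds` is a hypothetical name added only for illustration; it states at run time the property that the types should eventually enforce. Note that seeding with 0 assumes the entries are non-negative.

```haskell
-- Walk the row, remembering the previous element and accumulating
-- the larger of each adjacent pair (in reverse order).
aux :: Ord a => [a] -> a -> [a] -> [a]
aux [] prev acc = prev : acc
aux (hd:tl) prev acc = aux tl hd (max hd prev : acc)

-- The seed 0 stands in for the missing left neighbor of the first element.
expand :: (Ord a, Num a) => [a] -> [a]
expand row = reverse (aux row 0 [])

-- Hypothetical helper: the invariant, checked at run time rather than by types.
lengthInvariantHolds :: [Int] -> Bool
lengthInvariantHolds row = length (expand row) == length row + 1
```

For instance, `expand [3, 1, 4]` evaluates to `[3, 3, 4, 4]`: three inputs, four outputs.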

I can prove to myself, informally, that the length of the result of this function (which I originally wrote in OCaml) is one greater than the length of the input. But I hopped over to Haskell (a new language for me) because I wanted to prove via the type system that this invariant holds. With the help of this previous answer, I was able to get as far as the following:

{-# LANGUAGE GADTs, TypeOperators, TypeFamilies #-}

data Z
data S n

type family (:+:) a b :: *
type instance (:+:) Z n = n
type instance (:+:) (S m) n = S (m :+: n)

-- A List of length 'n' holding values of type 'a'
data List a n where
    Nil  :: List a Z
    Cons :: a -> List a m -> List a (S m)

aux :: List a n -> a -> List a m -> List a (n :+: (S m))
aux Nil prev acc = Cons prev acc
aux (Cons hd tl) prev acc = aux tl hd (Cons (max hd prev) acc)

However, the last line produces the following error:

* Could not deduce: (m1 :+: S (S m)) ~ S (m1 :+: S m)
  from the context: n ~ S m1
    bound by a pattern with constructor:
               Cons :: forall a m. a -> List a m -> List a (S m),
             in an equation for `aux'
    at pyramid.hs:23:6-15
  Expected type: List a (n :+: S m)
    Actual type: List a (m1 :+: S (S m))
* In the expression: aux tl hd (Cons (max hd prev) acc)
  In an equation for `aux':
      aux (Cons hd tl) prev acc = aux tl hd (Cons (max hd prev) acc)
* Relevant bindings include
    acc :: List a m (bound at pyramid.hs:23:23)
    tl :: List a m1 (bound at pyramid.hs:23:14)
    aux :: List a n -> a -> List a m -> List a (n :+: S m)
      (bound at pyramid.hs:22:1)

It seems that what I need to do is teach the compiler that (x :+: (S y)) ~ S (x :+: y). Is this possible?

Alternatively, are there better tools for this problem than the type system?


  • First, some imports and language extensions:

    {-# LANGUAGE GADTs, TypeInType, RankNTypes, TypeOperators, TypeFamilies, TypeApplications, AllowAmbiguousTypes #-}
    import Data.Type.Equality

    We now have DataKinds (or TypeInType), which allows us to promote any data type to the type level (with its own kind), so the type-level naturals really deserve to be defined as a regular data type (heck, this is exactly the motivating example that the previously linked GHC docs give!). Nothing changes with your List type, but (:+:) really should be a closed type family (now over things of kind Nat); it is written (+) below.

    -- A natural number type (that can be promoted to the type level)
    data Nat = Z | S Nat
    -- A List of length 'n' holding values of type 'a'
    data List a n where
      Nil  :: List a Z
      Cons :: a -> List a m -> List a (S m)
    type family (+) (a :: Nat) (b :: Nat) :: Nat where
      Z + n = n
      S m + n = S (m + n)

    Now, in order to make the proofs work for aux, it is useful to define singleton types for natural numbers.

    -- A singleton type for `Nat`
    data SNat n where
      SZero :: SNat Z
      SSucc :: SNat n -> SNat (S n)
    -- Utility for taking the predecessor of an `SNat`
    sub1 :: SNat (S n) -> SNat n
    sub1 (SSucc x) = x
    -- Find the size of a list
    size :: List a n -> SNat n
    size Nil = SZero
    size (Cons _ xs) = SSucc (size xs)

    Now, we are in shape to start proving some stuff. From Data.Type.Equality, a :~: b represents a proof that a ~ b. We need to prove one simple thing about arithmetic.

    -- Proof that     n + (S m) == S (n + m)
    plusSucc :: SNat n -> SNat m -> (n + S m) :~: S (n + m)
    plusSucc SZero _ = Refl
    plusSucc (SSucc n) m = gcastWith (plusSucc n m) Refl
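
If gcastWith feels opaque, it may help to see the same proof written by pattern matching on Refl directly. This is only a sketch: `plusSucc'` is a hypothetical name for the alternative spelling, and the block repeats the definitions it needs (using DataKinds with ticked constructors) so that it compiles on its own.

```haskell
{-# LANGUAGE GADTs, DataKinds, KindSignatures, TypeOperators, TypeFamilies #-}
import Data.Type.Equality

data Nat = Z | S Nat

type family (a :: Nat) + (b :: Nat) :: Nat where
  'Z   + n = n
  'S m + n = 'S (m + n)

data SNat (n :: Nat) where
  SZero :: SNat 'Z
  SSucc :: SNat n -> SNat ('S n)

-- Same statement as plusSucc: n + S m == S (n + m).
plusSucc' :: SNat n -> SNat m -> (n + 'S m) :~: 'S (n + m)
plusSucc' SZero _ = Refl
plusSucc' (SSucc n) m =
  -- Matching on Refl brings the inductive hypothesis
  -- (n + S m) ~ S (n + m) into scope for the right-hand side.
  case plusSucc' n m of
    Refl -> Refl
```

Matching on Refl and calling gcastWith do the same job here: both make the equality carried by the proof available to the type checker in the enclosed expression.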

    Finally, we can use gcastWith to use this proof in aux. Oh and you were missing the Ord a constraint. :)

    aux :: Ord a => List a n -> a -> List a m -> List a (n + S m)
    aux Nil prev acc = Cons prev acc
    aux (Cons hd tl) prev acc = gcastWith (plusSucc (size tl) (SSucc (size acc)))
                                          (aux tl hd (Cons (max hd prev) acc))
    -- append to a list
    (|>) :: List a n -> a -> List a (S n)
    Nil |> y = Cons y Nil
    (Cons x xs) |> y = Cons x (xs |> y)
    -- reverse 'List'
    rev :: List a n -> List a n
    rev Nil = Nil
    rev (Cons x xs) = rev xs |> x
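
The pieces above stop short of the `expand` from the question, so here is one way they might be assembled, as a sketch rather than the definitive version. It assumes one extra lemma, `plusZero` (a hypothetical name, proving n + Z == n), because `aux row 0 Nil` has length n + S Z, which the compiler will not reduce to S n on its own. `toList` is likewise a hypothetical helper for inspecting results. The block repeats the earlier definitions, with DataKinds and ticked constructors, so that it compiles by itself.

```haskell
{-# LANGUAGE GADTs, DataKinds, KindSignatures, TypeOperators, TypeFamilies #-}
import Data.Type.Equality

data Nat = Z | S Nat

data List a (n :: Nat) where
  Nil  :: List a 'Z
  Cons :: a -> List a m -> List a ('S m)

type family (a :: Nat) + (b :: Nat) :: Nat where
  'Z   + n = n
  'S m + n = 'S (m + n)

data SNat (n :: Nat) where
  SZero :: SNat 'Z
  SSucc :: SNat n -> SNat ('S n)

size :: List a n -> SNat n
size Nil = SZero
size (Cons _ xs) = SSucc (size xs)

-- n + S m == S (n + m)
plusSucc :: SNat n -> SNat m -> (n + 'S m) :~: 'S (n + m)
plusSucc SZero _ = Refl
plusSucc (SSucc n) m = gcastWith (plusSucc n m) Refl

-- Hypothetical extra lemma: n + Z == n
plusZero :: SNat n -> (n + 'Z) :~: n
plusZero SZero = Refl
plusZero (SSucc n) = gcastWith (plusZero n) Refl

aux :: Ord a => List a n -> a -> List a m -> List a (n + 'S m)
aux Nil prev acc = Cons prev acc
aux (Cons hd tl) prev acc =
  gcastWith (plusSucc (size tl) (SSucc (size acc)))
            (aux tl hd (Cons (max hd prev) acc))

-- Append one element to the end of a list.
(|>) :: List a n -> a -> List a ('S n)
Nil |> y = Cons y Nil
Cons x xs |> y = Cons x (xs |> y)

rev :: List a n -> List a n
rev Nil = Nil
rev (Cons x xs) = rev xs |> x

-- The result type says it all: one element longer than the input.
expand :: (Ord a, Num a) => List a n -> List a ('S n)
expand row =
  gcastWith (plusSucc n SZero)        -- n + S Z == S (n + Z)
    (gcastWith (plusZero n)           -- n + Z   == n
      (rev (aux row 0 Nil)))
  where
    n = size row

-- Hypothetical helper: forget the length index to inspect a result.
toList :: List a n -> [a]
toList Nil = []
toList (Cons x xs) = x : toList xs
```

With these definitions, `toList (expand (Cons 3 (Cons 1 (Cons 4 Nil))))` evaluates to `[3, 3, 4, 4]`, and any edit to `aux` or `expand` that changes the output length is rejected at compile time.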

    Let me know if this answers your question - getting started with this sort of thing involves a lot of new stuff.