
Haskell data type definition dependent on GADTs and function output


I would like to have a tensor data structure:

data Nat where
    Zero :: Nat
    Succ :: Nat -> Nat

-- | A list of type a and of length n
data ListN a (dim :: Nat) where
    Nil  :: ListN a Zero
    Cons :: a -> ListN a n -> ListN a (Succ n)    

data Tensor a where
        Dense :: ListN a n -> ListN Int Nat -> Tensor a

A tensor is represented by a list of elements and a list of integers giving its dimensions. For example, [3,4,5,6] in the second ListN would mean the tensor has 4 dimensions, with 3, 4, 5 and 6 elements respectively. Now I want the n of the first ListN to depend on the product of all the integers stored in the second ListN, because that is the number of elements the first ListN can hold. How should I do that?


Solution

  • To do this, you'll need a type-level dimension vector for your Tensor type, not just a ListN Int Nat value, so it's probably better to define Tensor with a dims type parameter. You may also find it more convenient to have the dimensions first and the element type second, so something like:

    data ListN (dim :: Nat) a where
        Nil  :: ListN Zero a
        Cons :: a -> ListN n a -> ListN (Succ n) a
    infixr 5 `Cons`
    
    data Tensor (dims :: [Nat]) a where
      Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a
    

    The missing piece here is Product, which is a type-level function that multiplies the dimensions. It's a little tedious to multiply Peano naturals, but the following works:

    type family Plus m n where
      Plus (Succ m) n = Plus m (Succ n)
      Plus Zero n = n
    
    type family Times m n where
      Times (Succ m) n = Plus n (Times m n)
      Times Zero n = Zero
    
    type family Product (dims) where
      Product '[] = Succ Zero
      Product (m : ns) = Times m (Product ns)
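
    One way to convince yourself that these type families reduce as intended is to ask GHC for an equality proof. The sketch below restates the families with explicit kind signatures and ticked constructors; the synonyms One, Two, Three, Six and the name productCheck are illustrative and not part of the definitions above. The module compiles only if Product '[One, Two, Three] reduces to Six.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

import Data.Type.Equality ((:~:) (Refl))

data Nat = Zero | Succ Nat

type family Plus (m :: Nat) (n :: Nat) :: Nat where
  Plus ('Succ m) n = Plus m ('Succ n)
  Plus 'Zero n = n

type family Times (m :: Nat) (n :: Nat) :: Nat where
  Times ('Succ m) n = Plus n (Times m n)
  Times 'Zero n = 'Zero

type family Product (dims :: [Nat]) :: Nat where
  Product '[] = 'Succ 'Zero
  Product (m ': ns) = Times m (Product ns)

-- Illustrative synonyms (assumptions, only here to keep the signature short).
type One = 'Succ 'Zero
type Two = 'Succ One
type Three = 'Succ Two
type Six = 'Succ ('Succ ('Succ Three))

-- Refl is accepted only when both sides reduce to the same type,
-- so this is a compile-time check that 1 * 2 * 3 = 6 at the type level.
productCheck :: Product '[One, Two, Three] :~: Six
productCheck = Refl
```

    Changing Six to any other number in the signature of productCheck should turn this into a type error, which is the same mechanism that the Dense constraint relies on.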
    

    After that, the following type checks. Note that I've made Cons an infixr operator up above to avoid a lot of parentheses:

    t1 :: Tensor '[Succ Zero, Succ (Succ Zero), Succ (Succ (Succ Zero))] Int
    t1 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)
    

    If the number of elements is wrong, the constraint fails, so the following does not type check:

    t2 :: Tensor '[Succ Zero, Succ (Succ Zero), Succ (Succ (Succ Zero))] Int
    t2 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` Nil)
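
    Once a value such as t1 exists, pattern matching on Dense brings the Product dims ~ n constraint back into scope and exposes the underlying ListN. As a minimal sketch of consuming a tensor, the helpers listNToList and tensorToList below are hypothetical names introduced here for illustration; they flatten the payload into an ordinary list.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

data Nat where
  Zero :: Nat
  Succ :: Nat -> Nat

data ListN (dim :: Nat) a where
  Nil  :: ListN 'Zero a
  Cons :: a -> ListN n a -> ListN ('Succ n) a
infixr 5 `Cons`

type family Plus (m :: Nat) (n :: Nat) :: Nat where
  Plus ('Succ m) n = Plus m ('Succ n)
  Plus 'Zero n = n

type family Times (m :: Nat) (n :: Nat) :: Nat where
  Times ('Succ m) n = Plus n (Times m n)
  Times 'Zero n = 'Zero

type family Product (dims :: [Nat]) :: Nat where
  Product '[] = 'Succ 'Zero
  Product (m ': ns) = Times m (Product ns)

data Tensor (dims :: [Nat]) a where
  Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a

-- Hypothetical helper: forget the length index and return a plain list.
listNToList :: ListN n a -> [a]
listNToList Nil = []
listNToList (Cons x xs) = x : listNToList xs

-- Hypothetical helper: unwrap Dense and flatten the elements in storage order.
tensorToList :: Tensor dims a -> [a]
tensorToList (Dense xs) = listNToList xs

t1 :: Tensor '[ 'Succ 'Zero, 'Succ ('Succ 'Zero), 'Succ ('Succ ('Succ 'Zero))] Int
t1 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)

-- tensorToList t1 evaluates to [1,2,3,4,5,6]
```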
    

    The full example:

    {-# LANGUAGE DataKinds #-}
    {-# LANGUAGE GADTs #-}
    {-# LANGUAGE KindSignatures #-}
    {-# LANGUAGE PolyKinds #-}
    {-# LANGUAGE TypeApplications #-}
    {-# LANGUAGE TypeFamilies #-}
    {-# LANGUAGE TypeOperators #-}
    {-# LANGUAGE UndecidableInstances #-}
    
    data Nat where
        Zero :: Nat
        Succ :: Nat -> Nat
    
    data ListN (dim :: Nat) a where
        Nil  :: ListN Zero a
        Cons :: a -> ListN n a -> ListN (Succ n) a
    infixr 5 `Cons`
    
    data Tensor (dims :: [Nat]) a where
      Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a
    
    type family Plus m n where
      Plus (Succ m) n = Plus m (Succ n)
      Plus Zero n = n
    
    type family Times m n where
      Times (Succ m) n = Plus n (Times m n)
      Times Zero n = Zero
    
    type family Product (dims) where
      Product '[] = Succ Zero
      Product (m : ns) = Times m (Product ns)
    
    -- type checks
    t1 :: Tensor '[Succ Zero, Succ (Succ Zero), Succ (Succ (Succ Zero))] Int
    t1 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)
    
    -- won't type check
    t2 :: Tensor '[Succ Zero, Succ (Succ Zero), Succ (Succ (Succ Zero))] Int
    t2 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` Nil)
    

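    As noted in the comments, there is a built-in non-Peano Nat type, provided by GHC.TypeLits, that you may find easier to work with. The version below enables NoStarIsType so that * is read as type-level multiplication rather than as the kind of types. Rewritten to use the built-in Nat, the code would look like this: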

    {-# LANGUAGE DataKinds #-}
    {-# LANGUAGE GADTs #-}
    {-# LANGUAGE KindSignatures #-}
    {-# LANGUAGE NoStarIsType #-}
    {-# LANGUAGE PolyKinds #-}
    {-# LANGUAGE TypeApplications #-}
    {-# LANGUAGE TypeFamilies #-}
    {-# LANGUAGE TypeOperators #-}
    {-# LANGUAGE UndecidableInstances #-}
    
    import GHC.TypeLits
    
    data ListN (dim :: Nat) a where
        Nil  :: ListN 0 a
        Cons :: a -> ListN n a -> ListN (1 + n) a
    infixr 5 `Cons`
    
    data Tensor (dims :: [Nat]) a where
      Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a
    
    type family Product dims where
      Product '[] = 1
      Product (m : ns) = m * Product ns
    
    -- type checks
    t1 :: Tensor '[1,2,3] Int
    t1 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)
    
    -- won't type check
    t2 :: Tensor '[1,2,3] Int
    t2 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` Nil)
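
    With the built-in Nat, the dimensions can also be recovered at run time through KnownNat. The following is only a sketch under stated assumptions: the class KnownDims and the function shape are hypothetical names that do not appear in the code above, while KnownNat and natVal are the standard GHC.TypeLits ones.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits

data ListN (dim :: Nat) a where
  Nil  :: ListN 0 a
  Cons :: a -> ListN n a -> ListN (1 + n) a
infixr 5 `Cons`

data Tensor (dims :: [Nat]) a where
  Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a

type family Product (dims :: [Nat]) :: Nat where
  Product '[] = 1
  Product (m ': ns) = m * Product ns

-- Hypothetical class: turns a type-level list of naturals into a value.
class KnownDims (dims :: [Nat]) where
  dimsVal :: Proxy dims -> [Integer]

instance KnownDims '[] where
  dimsVal _ = []

instance (KnownNat d, KnownDims ds) => KnownDims (d ': ds) where
  dimsVal _ = natVal (Proxy :: Proxy d) : dimsVal (Proxy :: Proxy ds)

-- Hypothetical helper: the shape is read from the type, not from the data.
shape :: forall dims a. KnownDims dims => Tensor dims a -> [Integer]
shape _ = dimsVal (Proxy :: Proxy dims)

t1 :: Tensor '[1, 2, 3] Int
t1 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)

-- shape t1 evaluates to [1,2,3]
```

    Because the shape lives only in the type, shape never inspects its argument; the KnownDims constraint supplies everything it needs.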