I would like to have a tensor data structure like this:
data Nat where
  Zero :: Nat
  Succ :: Nat -> Nat
-- | A list of type a and of length n
data ListN a (dim :: Nat) where
  Nil  :: ListN a Zero
  Cons :: a -> ListN a n -> ListN a (Succ n)
data Tensor a where
  Dense :: ListN a n -> ListN Int Nat -> Tensor a
A tensor is represented by a list of elements and a list of integers giving the dimensions of the tensor. For example, [3,4,5,6] in the ListN would mean the tensor has 4 dimensions, which are 3, 4, 5 and 6 elements long, respectively. But now I want the n of the first ListN to depend on the product of all the integers stored in the second ListN, because that is the number of elements the first ListN can hold. How should I do that?
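For comparison, the rule being asked for (number of elements = product of the dimensions) can be checked at run time with plain lists and a smart constructor. This is a minimal sketch; DynTensor and mkDynTensor are hypothetical names, and unlike a type-level solution it only detects a mismatch when the program runs.

```haskell
-- Hypothetical run-time version: the shape is an ordinary list of Ints,
-- and the element count is checked when the tensor is built.
data DynTensor a = DynTensor [Int] [a]
  deriving Show

-- Returns Nothing unless there are exactly (product dims) elements.
mkDynTensor :: [Int] -> [a] -> Maybe (DynTensor a)
mkDynTensor dims xs
  | length xs == product dims = Just (DynTensor dims xs)
  | otherwise                 = Nothing
```

For the shape [3,4,5,6], product [3,4,5,6] is 360, so mkDynTensor [3,4,5,6] would accept exactly 360 elements.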
To do this, you'll need a type-level dimension vector for your Tensor type, not just a ListN Int Nat value, so it's probably better to define Tensor with a dims type parameter. You may also find it more convenient to have the dimensions first and the element type second, so something like:
data ListN (dim :: Nat) a where
  Nil  :: ListN Zero a
  Cons :: a -> ListN n a -> ListN (Succ n) a
infixr 5 `Cons`
data Tensor (dims :: [Nat]) a where
  Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a
The missing piece here is Product, a type-level function that multiplies the dimensions together. It's a little tedious to multiply Peano naturals, but the following works:
type family Plus m n where
  Plus (Succ m) n = Plus m (Succ n)
  Plus Zero n = n
type family Times m n where
  Times (Succ m) n = Plus n (Times m n)
  Times Zero n = Zero
type family Product dims where
  Product '[] = Succ Zero
  Product (m : ns) = Times m (Product ns)
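One way to check that these families reduce as intended is to ask GHC for an equality proof with :~: from Data.Type.Equality: Refl is accepted only if both sides normalise to the same type. A minimal sketch, assuming the same Peano Nat (written here with ordinary data syntax so the block stands alone); Two, Three, Six and productTwoThree are hypothetical names introduced for illustration.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

import Data.Type.Equality ((:~:) (Refl))

-- Same Peano naturals, in ordinary data syntax.
data Nat = Zero | Succ Nat

type family Plus (m :: Nat) (n :: Nat) :: Nat where
  Plus (Succ m) n = Plus m (Succ n)
  Plus Zero n = n

type family Times (m :: Nat) (n :: Nat) :: Nat where
  Times (Succ m) n = Plus n (Times m n)
  Times Zero n = Zero

type family Product (dims :: [Nat]) :: Nat where
  Product '[] = Succ Zero
  Product (m : ns) = Times m (Product ns)

-- Hypothetical synonyms for readability.
type Two   = Succ (Succ Zero)
type Three = Succ Two
type Six   = Succ (Succ (Succ Three))

-- Refl type checks only if both sides reduce to the same numeral,
-- so this is a compile-time proof that Product '[Two, Three] is 6.
productTwoThree :: Product '[Two, Three] :~: Six
productTwoThree = Refl
```

Changing Six to Succ (Succ Three) (that is, 5) makes productTwoThree fail to type check, which is a quick way to test a type family in isolation.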
With Product defined, the following type checks. Note that I've given Cons an infixr fixity declaration above to avoid a lot of parentheses:
t1 :: Tensor '[Succ Zero, Succ (Succ Zero), Succ (Succ (Succ Zero))] Int
t1 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)
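The fixity declaration matters because an operator without one defaults to infixl 9, so 1 `Cons` 2 `Cons` Nil would otherwise group as (1 `Cons` 2) `Cons` Nil and fail to type check. The sketch below (listNToList, withFixity and withParens are hypothetical names) shows that with infixr 5 the chained form and the fully parenthesised form build the same list.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}

data Nat = Zero | Succ Nat

data ListN (dim :: Nat) a where
  Nil  :: ListN Zero a
  Cons :: a -> ListN n a -> ListN (Succ n) a

-- Without this line, `Cons` would get the default fixity infixl 9.
infixr 5 `Cons`

-- Hypothetical helper: forget the length index.
listNToList :: ListN n a -> [a]
listNToList Nil         = []
listNToList (Cons x xs) = x : listNToList xs

-- These two definitions denote the same value.
withFixity :: ListN (Succ (Succ (Succ Zero))) Int
withFixity = 1 `Cons` 2 `Cons` 3 `Cons` Nil

withParens :: ListN (Succ (Succ (Succ Zero))) Int
withParens = Cons 1 (Cons 2 (Cons 3 Nil))
```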
If the number of elements is wrong, the constraint fails, so the following does not type check:
t2 :: Tensor '[Succ Zero, Succ (Succ Zero), Succ (Succ (Succ Zero))] Int
t2 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` Nil)
The full example:
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
data Nat where
  Zero :: Nat
  Succ :: Nat -> Nat
data ListN (dim :: Nat) a where
  Nil  :: ListN Zero a
  Cons :: a -> ListN n a -> ListN (Succ n) a
infixr 5 `Cons`
data Tensor (dims :: [Nat]) a where
  Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a
type family Plus m n where
  Plus (Succ m) n = Plus m (Succ n)
  Plus Zero n = n
type family Times m n where
  Times (Succ m) n = Plus n (Times m n)
  Times Zero n = Zero
type family Product dims where
  Product '[] = Succ Zero
  Product (m : ns) = Times m (Product ns)
-- type checks
t1 :: Tensor '[Succ Zero, Succ (Succ Zero), Succ (Succ (Succ Zero))] Int
t1 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)
-- won't type check
t2 :: Tensor '[Succ Zero, Succ (Succ Zero), Succ (Succ (Succ Zero))] Int
t2 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` Nil)
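Because every Tensor carries the Product constraint, functions over tensors get the length bookkeeping for free. The sketch below reuses the Peano version (Nat written with ordinary data syntax); listNToList, mapListN, zipWithListN, toList, tmap and tzipWith are hypothetical helpers, not part of the answer above. The interesting one is tzipWith: both arguments have the same dims, so GHC can conclude from the two Product dims ~ n constraints that their element lists have the same length, with no run-time check.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

data Nat = Zero | Succ Nat

data ListN (dim :: Nat) a where
  Nil  :: ListN Zero a
  Cons :: a -> ListN n a -> ListN (Succ n) a
infixr 5 `Cons`

type family Plus (m :: Nat) (n :: Nat) :: Nat where
  Plus (Succ m) n = Plus m (Succ n)
  Plus Zero n = n

type family Times (m :: Nat) (n :: Nat) :: Nat where
  Times (Succ m) n = Plus n (Times m n)
  Times Zero n = Zero

type family Product (dims :: [Nat]) :: Nat where
  Product '[] = Succ Zero
  Product (m : ns) = Times m (Product ns)

data Tensor (dims :: [Nat]) a where
  Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a

-- Hypothetical helpers on length-indexed lists.
listNToList :: ListN n a -> [a]
listNToList Nil         = []
listNToList (Cons x xs) = x : listNToList xs

mapListN :: (a -> b) -> ListN n a -> ListN n b
mapListN _ Nil         = Nil
mapListN f (Cons x xs) = Cons (f x) (mapListN f xs)

zipWithListN :: (a -> b -> c) -> ListN n a -> ListN n b -> ListN n c
zipWithListN _ Nil         Nil         = Nil
zipWithListN f (Cons x xs) (Cons y ys) = Cons (f x y) (zipWithListN f xs ys)

toList :: Tensor dims a -> [a]
toList (Dense xs) = listNToList xs

-- Mapping keeps the length, so the Product constraint still holds.
tmap :: (a -> b) -> Tensor dims a -> Tensor dims b
tmap f (Dense xs) = Dense (mapListN f xs)

-- Both element lists have length Product dims, so zipping is safe.
tzipWith :: (a -> b -> c) -> Tensor dims a -> Tensor dims b -> Tensor dims c
tzipWith f (Dense xs) (Dense ys) = Dense (zipWithListN f xs ys)

t1 :: Tensor '[Succ Zero, Succ (Succ Zero), Succ (Succ (Succ Zero))] Int
t1 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)
```

For example, toList (tzipWith (+) t1 t1) gives [2,4,6,8,10,12].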
As noted in the comments, there is a built-in, non-Peano Nat type (from GHC.TypeLits) that you may find easier to work with. Rewritten to use it, the code looks like this:
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
import GHC.TypeLits
data ListN (dim :: Nat) a where
  Nil  :: ListN 0 a
  Cons :: a -> ListN n a -> ListN (1 + n) a
infixr 5 `Cons`
data Tensor (dims :: [Nat]) a where
  Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a
type family Product dims where
  Product '[] = 1
  Product (m : ns) = m * Product ns
-- type checks
t1 :: Tensor '[1,2,3] Int
t1 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)
-- won't type check
t2 :: Tensor '[1,2,3] Int
t2 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` Nil)
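With the built-in Nat you also get KnownNat and natVal from GHC.TypeLits, so the element count can be recovered at run time from the type alone. A minimal sketch; tensorSize is a hypothetical helper name, and its signature needs the extra ScopedTypeVariables and FlexibleContexts pragmas.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits

data ListN (dim :: Nat) a where
  Nil  :: ListN 0 a
  Cons :: a -> ListN n a -> ListN (1 + n) a
infixr 5 `Cons`

type family Product (dims :: [Nat]) :: Nat where
  Product '[] = 1
  Product (m : ns) = m * Product ns

data Tensor (dims :: [Nat]) a where
  Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a

-- Hypothetical helper: reads the element count off the type with natVal,
-- without walking the list.
tensorSize :: forall dims a. KnownNat (Product dims) => Tensor dims a -> Integer
tensorSize _ = natVal (Proxy :: Proxy (Product dims))

t1 :: Tensor '[1, 2, 3] Int
t1 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)
```

With the t1 defined in this sketch, tensorSize t1 is 6, because GHC reduces Product '[1, 2, 3] to the literal 6 and solves KnownNat 6 automatically.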