Tags: haskell, type-safety, gadt, data-kinds

Haskell GADTs: making type-safe Tensor types for Riemannian geometry


I want to make a type-safe implementation of tensor calculus in Haskell using GADTs. The rules are:

  1. Tensors are n-dimensional matrices whose indices can be 'upstairs' or 'downstairs'. For example, $T$ is a tensor with no indices (a scalar), $T^{\mu}$ is a tensor with one 'upstairs' index, and $T^{\mu\nu}{}_{\alpha\beta}$ is a tensor with several 'upstairs' and 'downstairs' indices.
  2. You can ADD tensors of the same type, meaning they have the same index signature: the 0th index of the first tensor has the same position (upstairs or downstairs) as the 0th index of the second tensor, and so on.

    $A^{\mu}{}_{\nu} + B^{\mu}{}_{\nu}$ ~~~~ OK

    $A^{\mu}{}_{\nu} + B_{\mu}{}^{\nu}$ ~~~~ NOT OK

  3. You can MULTIPLY tensors to get bigger tensors, with the indices concatenated: $A^{\mu}{}_{\nu} B^{\alpha} = C^{\mu}{}_{\nu}{}^{\alpha}$

So I want Haskell's type checker to reject code that doesn't follow these rules; such code should not compile.

Here is my attempt using GADTs:

{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE TypeOperators #-}

data Direction = T | X | Y | Z
data Index = Zero | Up Index | Down Index deriving (Eq, Show)

plus :: Index -> Index -> Index
plus Zero x = x
plus (Up x) y = Up (plus x y)
plus (Down x) y = Down (plus x y)

data Tensor a = (a ~ Zero) => Scalar Double | 
                forall b. (a ~ Up b) => Cov (Direction -> Tensor b) |
                forall b. (a ~ Down b) => Con (Direction -> Tensor b) 

add :: Tensor a -> Tensor a -> Tensor a
add (Scalar x) (Scalar y) = (Scalar (x + y))
add (Cov f) (Cov g) = (Cov (\d -> add (f d) (g d)))
add (Con f) (Con g) = (Con (\d -> add (f d) (g d)))

mul :: Tensor a -> Tensor b -> Tensor (plus a b)
mul (Scalar x) (Scalar y) = (Scalar (x*y))
mul (Scalar x) (Cov f) = (Cov (\d -> mul (Scalar x) (f d)))
mul (Scalar x) (Con f) = (Con (\d -> mul (Scalar x) (f d)))
mul (Cov f) y = (Cov (\d -> mul (f d) y))
mul (Con f) y = (Con (\d -> mul (f d) y))

But I'm getting:

Couldn't match type 'Down with `plus ('Down b1)'
    Expected type: Tensor (plus a b)
      Actual type: Tensor ('Down b)
    Relevant bindings include
      f :: Direction -> Tensor b1 (bound at main.hs:28:10)
      mul :: Tensor a -> Tensor b -> Tensor (plus a b)
        (bound at main.hs:24:1)
    In the expression: (Con (\ d -> mul (f d) y))
    In an equation for `mul':
        mul (Con f) y = (Con (\ d -> mul (f d) y))

What is the problem?


Solution

  • plus is just a function on values of type Index

    >>> plus Zero Zero
    Zero
    >>> plus Zero (Up Zero)
    Up Zero
    

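    so it can't appear in a type signature as things are. In fact, inside a type signature a lowercase name like plus is read as an implicitly quantified type variable, not as the function, which is why GHC complains that 'Down doesn't match plus ('Down b1) rather than reporting that plus is not in scope. What you want is the promoted version of Index: with DataKinds, Zero, Up Zero, etc. are also types (of kind Index), so you can write a type-level function over them (a closed type family) and then everything compiles.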

    {-# LANGUAGE GADTs #-}
    {-# LANGUAGE DataKinds #-}
    {-# LANGUAGE ExistentialQuantification #-}
    {-# LANGUAGE TypeOperators #-}
    {-# LANGUAGE TypeFamilies #-}
    
    data Direction = T | X | Y | Z
    data Index = Zero | Up Index | Down Index deriving (Eq, Show)
    
    -- type function Plus
    type family Plus (i :: Index) (j :: Index) :: Index where
      Plus Zero x = x
      Plus (Up x) y  = Up (Plus x y)
      Plus (Down x) y = Down (Plus x y)
    
    -- value function plus
    plus :: Index -> Index -> Index
    plus Zero x = x
    plus (Up x) y = Up (plus x y)
    plus (Down x) y = Down (plus x y)
    
    data Tensor (a :: Index) where
      Scalar :: Double -> Tensor Zero
      Cov :: (Direction -> Tensor b) -> Tensor (Up b)
      Con :: (Direction -> Tensor b) -> Tensor (Down b)
    
    add :: Tensor a -> Tensor a -> Tensor a
    add (Scalar x) (Scalar y) = (Scalar (x + y))
    add (Cov f) (Cov g) = (Cov (\d -> add (f d) (g d)))
    add (Con f) (Con g) = (Con (\d -> add (f d) (g d)))
    
    mul :: Tensor a -> Tensor b -> Tensor (Plus a b)
    mul (Scalar x) (Scalar y) = (Scalar (x*y))
    mul (Scalar x) (Cov f) = (Cov (\d -> mul (Scalar x) (f d)))
    mul (Scalar x) (Con f) = (Con (\d -> mul (Scalar x) (f d)))
    mul (Cov f) y = (Cov (\d -> mul (f d) y))
    mul (Con f) y = (Con (\d -> mul (f d) y))
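A minimal sketch of how these definitions enforce rules 2 and 3 in practice. It repeats the answer's types with ticks added, and assumes a few hypothetical additions that are not in the answer: Direction also derives Enum, the component readers scalar, atUp and atDown, and the example tensors u and w. The block has no main; load it into GHCi and evaluate, for example, scalar (atDown (atUp uw Y) T), which gives 4.0.

```haskell
{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeFamilies #-}

-- Types from the answer; Direction additionally derives Enum (assumption)
-- so the example vector u can be built from it.
data Direction = T | X | Y | Z deriving (Eq, Show, Enum)
data Index = Zero | Up Index | Down Index deriving (Eq, Show)

type family Plus (i :: Index) (j :: Index) :: Index where
  Plus 'Zero x = x
  Plus ('Up x) y = 'Up (Plus x y)
  Plus ('Down x) y = 'Down (Plus x y)

data Tensor (a :: Index) where
  Scalar :: Double -> Tensor 'Zero
  Cov :: (Direction -> Tensor b) -> Tensor ('Up b)
  Con :: (Direction -> Tensor b) -> Tensor ('Down b)

add :: Tensor a -> Tensor a -> Tensor a
add (Scalar x) (Scalar y) = Scalar (x + y)
add (Cov f) (Cov g) = Cov (\d -> add (f d) (g d))
add (Con f) (Con g) = Con (\d -> add (f d) (g d))

mul :: Tensor a -> Tensor b -> Tensor (Plus a b)
mul (Scalar x) (Scalar y) = Scalar (x * y)
mul (Scalar x) (Cov f) = Cov (\d -> mul (Scalar x) (f d))
mul (Scalar x) (Con f) = Con (\d -> mul (Scalar x) (f d))
mul (Cov f) y = Cov (\d -> mul (f d) y)
mul (Con f) y = Con (\d -> mul (f d) y)

-- Hypothetical helpers (not part of the answer) for reading components.
scalar :: Tensor 'Zero -> Double
scalar (Scalar x) = x

atUp :: Tensor ('Up b) -> Direction -> Tensor b
atUp (Cov f) = f

atDown :: Tensor ('Down b) -> Direction -> Tensor b
atDown (Con f) = f

-- u has one upstairs index, with components T=0, X=1, Y=2, Z=3.
u :: Tensor ('Up 'Zero)
u = Cov (\d -> Scalar (fromIntegral (fromEnum d)))

-- w has one downstairs index, every component equal to 2.
w :: Tensor ('Down 'Zero)
w = Con (\_ -> Scalar 2)

-- Rule 2: both arguments have index signature 'Up 'Zero, so this type-checks.
uu :: Tensor ('Up 'Zero)
uu = add u u

-- Rule 3: the indices concatenate; GHC checks this annotation by reducing
-- Plus ('Up 'Zero) ('Down 'Zero) to 'Up ('Down 'Zero).
uw :: Tensor ('Up ('Down 'Zero))
uw = mul u w

-- Rejected at compile time ('Up 'Zero does not match 'Down 'Zero):
-- bad = add u w
```

The commented-out bad line is the point of the exercise: mixing an upstairs and a downstairs signature in add is a type error rather than a runtime failure.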
    

    There was no ambiguity in Plus, but I could have used the disambiguating tick ' to signal that I was dealing with the type-level Zero, Up, etc.:

    type family Plus (i :: Index) (j :: Index) :: Index where
      Plus 'Zero x = x
      Plus ('Up x) y  = 'Up (Plus x y)
      Plus ('Down x) y = 'Down (Plus x y)
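The tick only changes anything when an unticked name is genuinely ambiguous, that is, when a type constructor with the same name as a data constructor is in scope. A small sketch, assuming a hypothetical extra type data Zero that does not exist in the question; describeIndex, ordinary and promoted are also illustrative names. Load it into GHCi (it has no main).

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}

import Data.Proxy (Proxy (..))

data Index = Zero | Up Index | Down Index

-- Hypothetical: an ordinary type that reuses the name Zero.
-- Types and data constructors live in different namespaces, so this is legal.
data Zero = MkZero

-- Only accepts proxies whose argument has kind Index.
describeIndex :: Proxy (i :: Index) -> String
describeIndex _ = "an Index-kinded type"

-- Unticked: Zero in a type now means the type constructor Zero (kind Type).
ordinary :: Proxy Zero
ordinary = Proxy

-- Ticked: 'Zero is the promoted data constructor, of kind Index.
promoted :: Proxy 'Zero
promoted = Proxy

-- describeIndex promoted   -- accepted
-- describeIndex ordinary   -- rejected: expected kind Index, but Zero has kind Type
```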
    

    TypeOperators would let you write a + b rather than Plus a b above, so mul's result type could be written Tensor (a + b):

    type family (i :: Index) + (j :: Index) :: Index where
      Zero + x = x
      Up x + y  = Up (x + y)
      Down x + y = Down (x + y)
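A sketch of the operator version wired into mul, assuming everything else in the answer stays the same and only Plus a b is replaced by a + b. The component readers scalar and atUp and the value scaled are hypothetical additions for demonstration. Load it into GHCi (it has no main).

```haskell
{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

data Direction = T | X | Y | Z
data Index = Zero | Up Index | Down Index deriving (Eq, Show)

-- The answer's infix type family, with ticks added for clarity.
-- 'Up x + y parses as ('Up x) + y because application binds tighter.
type family (i :: Index) + (j :: Index) :: Index where
  'Zero + x = x
  'Up x + y = 'Up (x + y)
  'Down x + y = 'Down (x + y)

data Tensor (a :: Index) where
  Scalar :: Double -> Tensor 'Zero
  Cov :: (Direction -> Tensor b) -> Tensor ('Up b)
  Con :: (Direction -> Tensor b) -> Tensor ('Down b)

-- Same recursion as the answer's mul; only the result type uses the operator.
mul :: Tensor a -> Tensor b -> Tensor (a + b)
mul (Scalar x) (Scalar y) = Scalar (x * y)
mul (Scalar x) (Cov f) = Cov (\d -> mul (Scalar x) (f d))
mul (Scalar x) (Con f) = Con (\d -> mul (Scalar x) (f d))
mul (Cov f) y = Cov (\d -> mul (f d) y)
mul (Con f) y = Con (\d -> mul (f d) y)

-- Hypothetical helpers for reading components back out.
scalar :: Tensor 'Zero -> Double
scalar (Scalar x) = x

atUp :: Tensor ('Up b) -> Direction -> Tensor b
atUp (Cov f) = f

-- 'Zero + 'Up 'Zero reduces to 'Up 'Zero, so this signature is accepted.
scaled :: Tensor ('Up 'Zero)
scaled = mul (Scalar 2) (Cov (\_ -> Scalar 5))
```

Evaluating scalar (atUp scaled X) gives 10.0: each component of the vector is multiplied by the scalar 2.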