I'm trying to create some functions to work with the following type. The following code uses the singletons and constraints libraries on GHC-8.4.1:
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE UndecidableInstances #-}
import Data.Constraint ((:-))
import Data.Singletons (sing)
import Data.Singletons.Prelude (Sing(SEQ, SGT, SLT), (%+), sCompare)
import Data.Singletons.Prelude.Num (PNum((+)))
import Data.Singletons.TypeLits (SNat)
import GHC.TypeLits (CmpNat, Nat)
data Foo where
  Foo
    :: forall (index :: Nat) (len :: Nat).
       (CmpNat index len ~ 'LT)
    => SNat index
    -> SNat len
    -> Foo
This is a GADT that contains a length and an index. The index is guaranteed to be less than the length.
It is easy enough to write a function to create a Foo:
createFoo :: Foo
createFoo = Foo (sing :: SNat 0) (sing :: SNat 1)
However, I'm having trouble writing a function that increments the len while keeping the index the same:
incrementLength :: Foo -> Foo
incrementLength (Foo index len) = Foo index (len %+ (sing :: SNat 1))
This is failing with the following error:
file.hs:34:34: error:
• Could not deduce: CmpNat index (len GHC.TypeNats.+ 1) ~ 'LT
arising from a use of ‘Foo’
from the context: CmpNat index len ~ 'LT
bound by a pattern with constructor:
Foo :: forall (index :: Nat) (len :: Nat).
(CmpNat index len ~ 'LT) =>
SNat index -> SNat len -> Foo,
in an equation for ‘incrementLength’
at what5.hs:34:17-29
• In the expression: Foo index (len %+ (sing :: SNat 1))
In an equation for ‘incrementLength’:
incrementLength (Foo index len)
= Foo index (len %+ (sing :: SNat 1))
• Relevant bindings include
len :: SNat len (bound at what5.hs:34:27)
index :: SNat index (bound at what5.hs:34:21)
|
34 | incrementLength (Foo index len) = Foo index (len %+ (sing :: SNat 1))
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
This makes sense, since the compiler knows that CmpNat index len ~ 'LT (from the definition of Foo), but doesn't know that CmpNat index (len + 1) ~ 'LT.
Is there any way to get something like this to work?
It is possible to use sCompare to do something like this:
incrementLength :: Foo -> Foo
incrementLength (Foo index len) =
  case sCompare index (len %+ (sing :: SNat 1)) of
    SLT -> Foo index (len %+ (sing :: SNat 1))
    SEQ -> error "not eq"
    SGT -> error "not gt"
However, it seems unfortunate that I have to write cases for SEQ and SGT, when I know they will never be matched.
Also, I thought it might be possible to write a proof term like the following:
plusOneLTProof :: (CmpNat n m ~ 'LT) :- (CmpNat n (m + 1) ~ 'LT)
plusOneLTProof = undefined
However, this gives an error like the following:
file.hs:40:8: error:
• Couldn't match type ‘CmpNat n0 m0’ with ‘CmpNat n m’
Expected type: (CmpNat n m ~ 'LT) :- (CmpNat n (m + 1) ~ 'LT)
Actual type: (CmpNat n0 m0 ~ 'LT) :- (CmpNat n0 (m0 + 1) ~ 'LT)
NB: ‘CmpNat’ is a non-injective type family
The type variables ‘n0’, ‘m0’ are ambiguous
• In the ambiguity check for ‘bar’
To defer the ambiguity check to use sites, enable AllowAmbiguousTypes
In the type signature:
bar :: (CmpNat n m ~ 'LT) :- (CmpNat n (m + 1) ~ 'LT)
|
40 | bar :: (CmpNat n m ~ 'LT) :- (CmpNat n (m + 1) ~ 'LT)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
This makes sense, I guess, since CmpNat is non-injective. However, I know that this implication is true, so I'd like to be able to write this function.
I'd like an answer to the following two questions:
Is there a way to write incrementLength where you'd only have to match on SLT? I'd be fine with changing the definition of Foo to make this easier.
Is there a way to write plusOneLTProof, or at least something similar?
Update: I ended up using the suggestion from Li-yao Xia to write plusOneLTProof and incrementLength like the following:
incrementLength :: Foo -> Foo
incrementLength (Foo index len) =
  case plusOneLTProof index len of
    Sub Dict -> Foo index (len %+ (sing :: SNat 1))

plusOneLTProof :: forall n m. SNat n -> SNat m -> (CmpNat n m ~ 'LT) :- (CmpNat n (m + 1) ~ 'LT)
plusOneLTProof SNat SNat = Sub axiom
  where
    axiom :: CmpNat n m ~ 'LT => Dict (CmpNat n (m + 1) ~ 'LT)
    axiom = unsafeCoerce (Dict :: Dict (a ~ a))
This requires that you pass in two SNats to plusOneLTProof, but it doesn't require AllowAmbiguousTypes.
The compiler is rejecting plusOneLTProof because its type is ambiguous. We can disable that check with the AllowAmbiguousTypes extension. I would recommend using it together with ExplicitForall (which is implied by ScopedTypeVariables, which we'll certainly need anyway, and by RankNTypes). That takes care of defining it. A definition that has an ambiguous type can then be used with TypeApplications.
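As a small illustration of how an ambiguous signature and TypeApplications fit together, here is a minimal sketch that uses only base. The names natValue and three are hypothetical helpers invented for this example; they are not part of the question or of any library.

```haskell
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

import Data.Proxy (Proxy (Proxy))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- Hypothetical helper: n occurs only in the constraint, so GHC cannot
-- infer it from the argument or result types. Without
-- AllowAmbiguousTypes this signature fails the ambiguity check.
natValue :: forall (n :: Nat). KnownNat n => Integer
natValue = natVal (Proxy :: Proxy n)

-- At the use site, the caller picks n explicitly with a type application.
three :: Integer
three = natValue @3
```

Loading this in GHCi and evaluating three should give 3. The explicit forall matters here: it fixes the order in which type arguments are supplied with @.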
However, GHC still can't reason about naturals, so we can't safely define plusOneLTProof = Sub Dict, much less incrementLength.
But we can still create the proof out of thin air with unsafeCoerce. This is in fact how the Data.Constraint.Nat module in constraints is implemented; unfortunately it currently doesn't contain any facts about CmpNat. The coercion works because type equalities have no runtime content. Even though the runtime values look fine, asserting incoherent facts this way can still cause programs to go wrong.
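-- Additional imports needed beyond those in the question:
import Data.Constraint (Dict (Dict), (:-) (Sub))
import Unsafe.Coerce (unsafeCoerce)

plusOneLTProof :: forall n m. (CmpNat n m ~ 'LT) :- (CmpNat n (m+1) ~ 'LT)
plusOneLTProof = Sub axiom
  where
    -- Unchecked assertion: GHC does not verify this implication.
    axiom :: (CmpNat n m ~ 'LT) => Dict (CmpNat n (m+1) ~ 'LT)
    axiom = unsafeCoerce (Dict :: Dict (a ~ a))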
To use the proof, we specialize it (with TypeApplications) and pattern match on it to introduce the RHS into the context; GHC checks that the LHS holds.
incrementLength :: Foo -> Foo
incrementLength (Foo (n :: SNat n) (m :: SNat m)) =
  case plusOneLTProof @n @m of
    Sub Dict -> Foo n (m %+ (sing :: SNat 1))
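For readers who want to try the same idea without singletons or constraints, the following is a rough sketch that depends only on base. It makes several assumptions that differ from the code above: Foo carries Proxy values instead of SNat, an extra KnownNat index constraint is added purely so the index can be observed, and the proof is returned as a (:~:) equality rather than a (:-) entailment. The names plusOneLT and fooIndex are hypothetical. The axiom is still an unchecked assertion made with unsafeCoerce.

```haskell
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

import Data.Proxy (Proxy (Proxy))
import Data.Type.Equality ((:~:) (Refl))
import GHC.TypeLits (CmpNat, KnownNat, Nat, natVal, type (+))
import Unsafe.Coerce (unsafeCoerce)

-- Simplified stand-in for the question's Foo (assumption: Proxy instead
-- of SNat, plus KnownNat index so the index can be read back).
data Foo where
  Foo
    :: forall (index :: Nat) (len :: Nat).
       (KnownNat index, CmpNat index len ~ 'LT)
    => Proxy index
    -> Proxy len
    -> Foo

createFoo :: Foo
createFoo = Foo (Proxy :: Proxy 0) (Proxy :: Proxy 1)

-- Unchecked axiom: if n < m then n < m + 1. GHC does not verify this;
-- the equality proof is conjured with unsafeCoerce.
plusOneLT :: forall n m. (CmpNat n m ~ 'LT) => CmpNat n (m + 1) :~: 'LT
plusOneLT = unsafeCoerce (Refl :: () :~: ())

-- Matching on Refl brings CmpNat index (len + 1) ~ 'LT into scope,
-- which is exactly what the Foo constructor demands.
incrementLength :: Foo -> Foo
incrementLength (Foo (i :: Proxy index) (_ :: Proxy len)) =
  case plusOneLT @index @len of
    Refl -> Foo i (Proxy :: Proxy (len + 1))

-- Hypothetical accessor used only to observe the result.
fooIndex :: Foo -> Integer
fooIndex (Foo i _) = natVal i
```

In GHCi, fooIndex (incrementLength createFoo) should evaluate to 0, since the index is left unchanged. As with the version above, the safety of this code rests entirely on the axiom being true for all naturals.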