idris dependent-type theorem-proving

How to prove propositional equality between a complicated expression and `True`?


I have some code that looks like this:

allLessThan : Ord t => (v1 : Vect n t) -> (v2 : Vect n t) -> Bool
allLessThan v1 v2 = all (\(x,y) => x < y) (zip v1 v2)

unravelIndexUnsafe : (order : ArrayOrder) ->
                     (shape : ArrayShape (S n)) ->
                     (position : Vect (S n) Nat) ->
                     Nat
unravelIndexUnsafe order shape position = ?someImplementation

unravelIndexSafe : (order : ArrayOrder) ->
                   (shape : ArrayShape (S n)) ->
                   (position : Vect (S n) Nat) ->
                   {auto 0 prfPositionValid : (allLessThan position shape) = True} ->
                   Nat
unravelIndexSafe order shape position = unravelIndexUnsafe order shape position

unravelIndex : (order : ArrayOrder) ->
               (shape : ArrayShape (S n)) ->
               (position : Vect (S n) Nat) ->
               Maybe Nat
unravelIndex order shape position =
  case allLessThan position shape of
    True => Just $ unravelIndexSafe order shape position
    False => Nothing

I omitted the implementation of unravelIndexUnsafe which I think is irrelevant to the question.

I get a type error in the definition of unravelIndex, saying that it can't find an implementation for prfPositionValid to use with unravelIndexSafe*.

This was surprising to me, because I am explicitly case splitting on allLessThan position shape, and only calling unravelIndexSafe in the True branch. I expected that Idris would be able to infer from this information that the proposition (allLessThan position shape) = True holds.

Is there a straightforward way to solve the problem? Maybe something I can explicitly construct and pass for the prfPositionValid implicit argument? Or is there an entirely different approach I should use here? Do I need to express prfPositionValid or allLessThan differently? Do I need to rewrite something?

* More precisely, it can't find an implementation for this monstrous "fully-expanded" version of prfPositionValid:

foldl (\acc, elem => acc && Delay (case block in allLessThan (S n) Nat (MkOrd (\{arg:354}, {arg:355} => compare arg arg) (\{arg:356}, {arg:357} => == (compare arg arg) LT) (\{arg:358}, {arg:359} => == (compare arg arg) GT) (\{arg:360}, {arg:361} => not (== (compare arg arg) GT)) (\{arg:362}, {arg:363} => not (== (compare arg arg) LT)) (\{arg:364}, {arg:365} => if == (compare arg arg) GT then x else y) (\{arg:366}, {arg:367} => if == (compare arg arg) LT then x else y)) shape position elem)) True (zipWith (\{__leftTupleSection:0}, {__infixTupleSection:0} => (__leftTupleSection, __infixTupleSection)) position shape) = True

Solution

  • Solution: use decidable equality

    The answer is to use "decidable equality": Idris will not derive an equality proof from an ordinary case split, so we have to ask for the proof explicitly.

    Note that the special = syntax normally stands for the builtin operator (===), which is equivalent to the type Equal. The only constructor of Equal is Refl, whose type is Equal x x. To prove a proposition of the form Equal a b, Idris must therefore be able to work out that a and b reduce to the same thing (call it c), so that Refl, with c as its implicit argument, can be given the type Equal a b. Conversely, the only way to construct a value of type Equal a b is Refl.

    Idris 2 does not infer propositional equality from a plain case split. I, a human, know that in the True branch allLessThan position shape is equal to True. To convince Idris, we need a value of type Equal (allLessThan position shape) True, which can only be Refl. A case expression inspects the Bool that allLessThan position shape evaluates to, but it does not record anywhere that the scrutinee equals the matched pattern, so no such value is in scope in the True branch. Therefore case-splitting as in the original code is not sufficient for Idris to find a proof of Equal (allLessThan position shape) True.
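
    As a small illustration of the point above, here is a minimal sketch (the names twoPlusTwo and notTrue are made up for this example). Refl is accepted when both sides reduce to the same value, and rejected when one side is stuck on an unknown variable, which is roughly the situation in the True branch of the original case.

    ```idris
    -- Both sides reduce to 4, so Refl has the required type.
    twoPlusTwo : plus 2 2 = 4
    twoPlusTwo = Refl

    -- not True reduces to False, so Refl works here too.
    notTrue : not True = False
    notTrue = Refl

    -- By contrast, the following is rejected, because b is an unknown Bool
    -- and Idris cannot reduce it to True:
    --
    --   stuck : (b : Bool) -> b = True
    --   stuck b = Refl
    ```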

    Both allLessThan position shape and True are Bools, and equality on Bool is decidable, so we can use decEq (from Decidable.Equality) to obtain the proof that we need. Therefore we can write unravelIndex as:

    unravelIndex : (order : ArrayOrder) ->
                   (shape : ArrayShape (S n)) ->
                   (position : Vect (S n) Nat) ->
                   Maybe Nat
    unravelIndex order shape position =
      case decEq (allLessThan position shape) True of
        Yes prf => Just $ unravelIndexSafe order shape position
        No contra => Nothing
    

    The prf in Yes prf is precisely the value of type Equal (allLessThan position shape) True we were looking for. Therefore Idris will be able to find a value for the prfPositionValid auto-implicit, because a value of the right type is available in scope. (I named it prf rather than proof because proof is also used as a keyword in Idris 2's with ... proof syntax.)

    You could also write _ instead of prf and contra, because the proofs are not explicitly used in the code anywhere, so they don't need names.
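
    For reference, here is a self-contained sketch of the same pattern that can be loaded on its own. It assumes that an array shape is just a Vect n Nat, and it uses a hypothetical sumIfValid as a stand-in for unravelIndexSafe, since the real implementation is not shown in the question. The name checkedSum is likewise made up for this example.

    ```idris
    import Data.Vect
    import Decidable.Equality

    allLessThan : Ord t => (v1 : Vect n t) -> (v2 : Vect n t) -> Bool
    allLessThan v1 v2 = all (\(x, y) => x < y) (zip v1 v2)

    -- Hypothetical stand-in for unravelIndexSafe: it demands an erased
    -- proof that every position is below the matching shape entry.
    sumIfValid : (shape : Vect n Nat) ->
                 (position : Vect n Nat) ->
                 {auto 0 prfValid : allLessThan position shape = True} ->
                 Nat
    sumIfValid shape position = sum position

    -- decEq on Bool returns either a proof of the equality (Yes) or a
    -- refutation (No); the proof bound in the Yes branch is what the
    -- auto-implicit search picks up.
    checkedSum : (shape : Vect n Nat) -> (position : Vect n Nat) -> Maybe Nat
    checkedSum shape position =
      case decEq (allLessThan position shape) True of
        Yes prf => Just (sumIfValid shape position)
        No _ => Nothing
    ```

    At the REPL, checkedSum [3, 4] [1, 2] should go through the Yes branch, while checkedSum [3, 4] [1, 7] should end up in the No branch and give Nothing.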

    Refactoring

    Note that the expression allLessThan position shape is somewhat ad hoc. In particular, stating the precondition requires the programmer to remember that specific expression. We would rather offer a tidier API, in which the programmer calls a function isIndexValidForShape to check validity, and uses a type IndexValidForShape to represent the "valid" state.

    allLessThan : Ord t => (v1 : Vect n t) -> (v2 : Vect n t) -> Bool
    allLessThan v1 v2 = all (\(x,y) => x < y) (zip v1 v2)
    
    IndexValidForShape : (shape : ArrayShape ndim) ->
                         (position : ArrayIndex ndim) ->
                         Type
    IndexValidForShape shape position =
      let isValid = allLessThan position shape
      in Equal isValid True
    
    isIndexValidForShape : (shape : ArrayShape (S n)) ->
                           (position : ArrayIndex (S n)) ->
                           Dec (IndexValidForShape shape position)
    isIndexValidForShape shape position =
      decEq (allLessThan position shape) True
    
    unravelIndexUnsafe : (order : ArrayOrder) ->
                         (shape : ArrayShape (S n)) ->
                         (position : ArrayIndex (S n)) ->
                         Nat
    unravelIndexUnsafe order shape position =
      sum $ zipWith (*) (strides order shape) position
    
    unravelIndexSafe : (order : ArrayOrder) ->
                       (shape : ArrayShape (S n)) ->
                       (position : ArrayIndex (S n)) ->
                       {auto 0 prfIndexValid : IndexValidForShape shape position} ->
                       Nat
    unravelIndexSafe order shape position =
      unravelIndexUnsafe order shape position
    
    unravelIndex : (order : ArrayOrder) ->
                   (shape : ArrayShape (S n)) ->
                   (position : ArrayIndex (S n)) ->
                   Maybe Nat
    unravelIndex order shape position =
      case isIndexValidForShape shape position of
        Yes _ => Just $ unravelIndexSafe order shape position
        No _ => Nothing
    

    Now the end user doesn't have to know or care what exactly IndexValidForShape entails, or that allLessThan is needed to check for it.

    In fact, we can now change what it means for an index to be "valid", mostly without affecting downstream code; maybe there are additional checks I want to put in place that I only learn about after I find a logic bug.

    Alternatively, it should be possible to re-design IndexValidForShape to be more "structural", wherein you inductively define a data type that represents the desired property. For example, refer to Data.Vect.Elem and its description in Chapter 9 of Type-Driven Development.
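
    A rough sketch of what such a structural definition could look like is given below. This is only one possible design, not something taken from the code above: the names AllLT, AllNil, AllCons and isAllLT are hypothetical, shapes and positions are modelled as plain Vect n Nat, and "x is less than y" is expressed as LTE (S x) y from Data.Nat.

    ```idris
    import Data.Vect
    import Data.Nat

    -- AllLT xs ys holds when every element of xs is strictly below the
    -- element of ys at the same position.
    data AllLT : Vect n Nat -> Vect n Nat -> Type where
      AllNil  : AllLT [] []
      AllCons : LTE (S x) y -> AllLT xs ys -> AllLT (x :: xs) (y :: ys)

    -- Decide the property element by element, using isLTE for the head
    -- and a recursive call for the tail.
    isAllLT : (xs : Vect n Nat) -> (ys : Vect n Nat) -> Dec (AllLT xs ys)
    isAllLT [] [] = Yes AllNil
    isAllLT (x :: xs) (y :: ys) =
      case isLTE (S x) y of
        No contra => No (\(AllCons p _) => contra p)
        Yes p =>
          case isAllLT xs ys of
            No contra => No (\(AllCons _ ps) => contra ps)
            Yes ps => Yes (AllCons p ps)
    ```

    With a definition along these lines, a proof carries one piece of evidence per dimension, so code that receives it can pattern match on AllCons instead of reasoning about the result of a fold over Bools.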

    Glossary

    • decidable: "a property is decidable if you can always say whether the property holds for some specific values" (quoted from Type-Driven Development, page 245).
    • Dec: The type representing the outcome of deciding a property. Its constructors are:
      • Yes : property -> Dec property - the property holds, and the argument is a proof of it.
      • No : (property -> Void) -> Dec property - the property does not hold, and the argument shows that any proof of it would lead to a contradiction.
    • DecEq: The interface for data types for which equality can be determined as a decidable property.
    • decEq: The method of DecEq that decides whether two values are equal. Its type is DecEq t => (x1 : t) -> (x2 : t) -> Dec (Equal x1 x2).
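
    To show the glossary terms working together, here is a minimal sketch in the spirit of the exactLength example from Type-Driven Development. The name tryLength is made up for this example. Matching on Yes Refl tells Idris that n and m are the same number, which is what lets xs be returned at the new type.

    ```idris
    import Data.Vect
    import Decidable.Equality

    -- n must be available at run time, hence the unerased implicit.
    tryLength : {n : Nat} -> (m : Nat) -> Vect n a -> Maybe (Vect m a)
    tryLength m xs =
      case decEq n m of
        Yes Refl => Just xs   -- here Idris knows n and m are equal
        No _ => Nothing       -- the refutation is not needed, so it is ignored
    ```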
