
How do you replace static assertions with prfuns?


Consider this unrefined (but working) program:

#include "share/atspre_staload.hats"

datatype class =
        | mage   | fighter | thief | cleric
        | wizard | warrior | ninja | priest

fn promoteclass(job: class): class =
        case- job of
        | mage() => wizard()
        | fighter() => warrior()
        | thief() => ninja()
        | cleric() => priest()

fn getsomeclass(): class = mage()

val- wizard() = promoteclass(getsomeclass())

implement main0() = ()

It's a runtime error to pass wizard() to promoteclass(), and it's also a runtime error if promoteclass(getsomeclass()) is changed to return something other than wizard().

That's no good! I'd much rather flip both of those - signs to + (case- and val- to case+ and val+) and get compile-time errors in both of the previous two cases. It would also be nice if accidentally transposing a promotion case, say writing priest() => cleric(), were a compile-time error.

This desire led to a refinement of the above, which also works just fine:

#include "share/atspre_staload.hats"

datatype class(int) =
        | mage(0)   | fighter(1) | thief(2) | cleric(3)
        | wizard(4) | warrior(5) | ninja(6) | priest(7)

fn promoteclass{n:int | n < 4}(job: class(n)): [m:int | m == n + 4] class(m) =
        case+ job of
        | mage() => wizard()
        | fighter() => warrior()
        | thief() => ninja()
        | cleric() => priest()

fn getsomeclass(): class(0) = mage()

val+ wizard() = promoteclass(getsomeclass())

implement main0() = ()

But what I'd like to do is replace the n < 4 and similar constraints above with dataprops and proof functions. Is that possible? Mainly I want to do that to better understand theorem proving in ATS, but it also seems to be the path to getting the same guarantees as the second example without all of its verbosity (especially as more functions that operate on these classes are added).

This is what I tried to do:

#include "share/atspre_staload.hats"

datatype class(int) =
        | mage(0)   | fighter(1) | thief(2) | cleric(3)
        | wizard(4) | warrior(5) | ninja(6) | priest(7)

dataprop promotable(int) =
        | {n:int}promotable_yes(n)
        | {n:int}promotable_no(n)

prfun test_promotable.<>.{n:int}():<> promotable(n) =
        sif n < 4 then promotable_yes{n}() else promotable_no{n}()

fn promoteclass{n:int}(job: class(n)): [m:int] class(m) =
        let
                prval promotable_yes() = test_promotable{n}()
        in
                case+ job of
                | mage() => wizard()
                | fighter() => warrior()
                | thief() => ninja()
                | cleric() => priest()
        end

fn getsomeclass(): class(0) = mage()

val+ wizard() = promoteclass(getsomeclass())

implement main0() = ()

But right away I'm told that the prval binding is non-exhaustive.


Solution

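  • The following code should fix the erasure error. It assumes promotable is redefined with one constructor per promotable class (pf_mage, pf_fighter, pf_thief, pf_cleric), and it has the caller pass the proof in as an argument instead of producing it inside promoteclass: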

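    // Assumed definition (not shown in the original answer): one proof
    // constructor for each class that can be promoted.
    dataprop promotable(int) =
      | pf_mage(0) | pf_fighter(1) | pf_thief(2) | pf_cleric(3)

    fn promoteclass{n:int}
      (pf: promotable(n) | job: class(n)): [m:int] class(m) =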
      (
        case+ job of
        | mage() => wizard()
        | fighter() => warrior()
        | thief() => ninja()
        | cleric() => priest()
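        // Catch-all for the remaining classes. =/=>> marks the clause as
        // unreachable, checked on the assumption that the four patterns
        // above did not match; the proof below derives a contradiction.
        | _ =/=>> () where
          {
            prval () =
            (
            case+ pf of
            | pf_mage() => ()
            | pf_fighter() => ()
            | pf_thief() => ()
            | pf_cleric() => ()
            ) : [false] void
          }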
      )
    

    You can move the proof code into a proof function; the type for the proof function is a bit complex. Here is what I have:

    prfn
    not_promotable
      {n:int | n != 0 && n != 1 && n != 2 && n != 3}
      (pf: promotable(n)):<> [false] void =
    (
    case+ pf of
    | pf_mage() => ()
    | pf_fighter() => ()
    | pf_thief() => ()
    | pf_cleric() => ()
    )
    
    fn promoteclass{n:int}
      (pf: promotable(n) | job: class(n)): [m:int] class(m) =
      (
        case+ job of
        | mage() => wizard()
        | fighter() => warrior()
        | thief() => ninja()
        | cleric() => priest()
        | _ =/=>> () where { prval () = not_promotable(pf) }
      )
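One thing the promoteclass above gives up is the result index: it returns [m:int] class(m), so a binding like val+ wizard() = promoteclass(...) from the second question program would no longer be exhaustive. A possible variation, not part of the original answer and offered only as a sketch assuming ATS2 (Postiats), is to index the proof by both the source and the target class and let a proof function turn it into the constraints the typechecker needs. The names promotes_to and promotes_to_lemma are hypothetical; the constructors reuse the answer's pf_* names.

```ats
#include "share/atspre_staload.hats"

datatype class(int) =
  | mage(0)   | fighter(1) | thief(2) | cleric(3)
  | wizard(4) | warrior(5) | ninja(6) | priest(7)

// Hypothetical relational proof: promotes_to(n, m) says that
// class n is promoted to class m. Not from the original answer.
dataprop promotes_to(int, int) =
  | pf_mage(0, 4) | pf_fighter(1, 5) | pf_thief(2, 6) | pf_cleric(3, 7)

// Hypothetical lemma: any proof of promotes_to(n, m) implies
// 0 <= n <= 3 and m == n + 4. Each branch is checked against
// its own indices, so a wrong pair in the dataprop is rejected.
prfn promotes_to_lemma{n,m:int}
  (pf: promotes_to(n, m)): [0 <= n && n <= 3 && m == n + 4] void =
  case+ pf of
  | pf_mage() => ()
  | pf_fighter() => ()
  | pf_thief() => ()
  | pf_cleric() => ()

fn promoteclass{n,m:int}
  (pf: promotes_to(n, m) | job: class(n)): class(m) = let
  // Unpack the lemma so the constraints on n and m are in scope.
  prval () = promotes_to_lemma(pf)
in
  // Exhaustive because 0 <= n <= 3 rules out wizard() .. priest();
  // each result must have index m == n + 4.
  case+ job of
  | mage() => wizard()
  | fighter() => warrior()
  | thief() => ninja()
  | cleric() => priest()
end

fn getsomeclass(): class(0) = mage()

// pf_mage() fixes n = 0 and m = 4, so the result is class(4)
// and this val+ pattern is exhaustive.
val+ wizard() = promoteclass(pf_mage() | getsomeclass())

implement main0() = ()
```

With this version both misuses from the question should become compile-time errors: there is no proof of promotes_to(4, m) to pass alongside wizard(), and a transposed clause such as mage() => warrior() yields class(5) where class(4) is required. A mistake in the dataprop itself, say pf_cleric(3, 5), is caught by the m == n + 4 check in promotes_to_lemma. The n + 4 arithmetic does not disappear; it moves from promoteclass's signature into a single lemma.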