Search code examples
dafny

Weakest preconditions for higher order recursive function


I am trying to play around with partial functions and higher-order functions in Dafny. I wrote funpow(n,f,t), which applies the function f n times to the value t.

// First attempt: apply f to t exactly n times.
// NOTE: as the text below explains, Dafny does not accept this precondition:
// the call funpow(i,f,t) inside it requires funpow's own precondition for i,
// which Dafny cannot establish at this point.
function funpow<T(!new)>(n:nat, f: T ~> T,t: T): T
requires forall i :: 0 <= i < n ==> f.requires(funpow(i,f,t))
reads f.reads
decreases n
{
  if n == 0
  then t
  else f(funpow(n-1,f,t))
}

I would like to express the most general precondition forall i :: 0 <= i < n ==> f.requires(funpow(i,f,t)), but Dafny rejects it.

The issue above is that we need to show Dafny that funpow.requires(i,f,t) is satisfied when calling funpow(i,f,t). I changed the precondition to forall i :: 0 <= i < n ==> funpow.requires(i,f,t) && f.requires(funpow(i,f,t)), and now I get Error: cannot use naked function in recursive setting. Possible solution: eta expansion.

I then tried n > 0 ==> funpow.requires(n-1,f,t) && f.requires(funpow(n-1,f,t)) as the precondition. I got the naked function error again.

I then looked up the reference manual and found that there is an {:autoReq} attribute that generates the precondition automatically. However, it does not seem to handle recursion.

I am currently using Dafny version 3.7.1.

How should I make this work?


Solution

  • I finally solved this problem by using a predicate to represent the function as a relation. We can then use the predicate to express the preconditions we want.

    // funpow_pred(n, f, t, t') holds exactly when t' is obtained from t by
    // n applications of f, with f's precondition satisfied at every step.
    predicate funpow_pred<T(!new)>(n:nat, f: T ~> T,t: T,t':T)
    reads f.reads
    {
      if n == 0 then t' == t
      else exists prev :: funpow_pred(n-1,f,t,prev) && f.requires(prev) && f(prev) == t'
    }
    
    // uniqueness
    // funpow_pred is functional: two results related to the same (n, f, t)
    // are equal. Proved by induction on n.
    lemma funpow_pred_uniq<T(!new)>(n:nat,f: T ~> T, t:T, t1:T, t2:T)
    requires funpow_pred(n,f,t,t1)
    requires funpow_pred(n,f,t,t2)
    ensures t1 == t2
    {
      if n > 0
      {
        // Fix one (n-1)-step result; by the induction hypothesis every
        // (n-1)-step result coincides with it.
        var mid :| funpow_pred(n-1,f,t,mid);
        forall other | funpow_pred(n-1,f,t,other)
        ensures other == mid
        {
          funpow_pred_uniq(n-1,f,t,mid,other);
        }
        // Hence both t1 and t2 are f applied to that single result.
        assert t1 == f(mid);
        assert t2 == f(mid);
      }
      else
      {
        // Base case: zero applications leave t unchanged.
        assert t1 == t;
        assert t2 == t;
      }
    }
    
    // use the predicates to state the preconditions
    // Applies f to t n times. The precondition only asks that, for n > 0, an
    // (n-1)-step result exists according to funpow_pred and lies in f's domain.
    // The postconditions tie the result back to funpow_pred (existence and
    // uniqueness); the first one is what makes the funpow_pred_uniq call in the
    // body applicable to the recursive result.
    function funpow<T(!new)>(n:nat,f:T~>T,t:T): T
    reads f.reads
    requires n > 0 ==> exists t'' :: funpow_pred(n-1,f,t,t'') && f.requires(t'')
    ensures funpow_pred(n,f,t,funpow(n,f,t))
    ensures forall t'' :: funpow_pred(n,f,t,t'') ==> t'' == funpow(n,f,t)
    {
      if n == 0 then t
      else
        // Ghost witness from the precondition: an (n-1)-step result f accepts.
        ghost var t'' :| funpow_pred(n-1,f,t,t'') && f.requires(t'');
        // By uniqueness, the recursive result equals that witness, so
        // f.requires holds for it and the application below is well-formed.
        funpow_pred_uniq(n-1,f,t,funpow(n-1,f,t),t'');
        assert funpow(n-1,f,t) == t'';
        f(funpow(n-1,f,t))
    }
    
    // check that this is exactly what I want
    // States that funpow's precondition is equivalent to the intended one:
    // for n > 0, funpow is defined at n-1 and f accepts that value.
    // No proof hints are given; the body is empty.
    lemma funpow_req_lem<T(!new)>(n:nat,f:T~>T,t:T)
    ensures funpow.requires(n,f,t) <==> (n > 0 ==> funpow.requires(n-1,f,t) && f.requires(funpow(n-1,f,t)))
    {
    }
    
    // inrange(i, n) holds when i lies in [0, n). It serves as the quantifier
    // trigger in funpow_pred_forall's postcondition.
    predicate inrange(i:nat,n:nat)
    {
      i < n // i is a nat, so 0 <= i already holds
    }
    
    // Holds when funpow_pred relates (i, f, t) to some value that f accepts,
    // i.e. f can be applied once more after i applications to t.
    predicate exists_funpow_pred<T(!new)>(i:nat,f: T ~> T, t:T)
    reads f.reads
    {
      exists r ::
        funpow_pred(i,f,t,r) &&
        f.requires(r)
    }
    
    // If funpow_pred relates (n, f, t) to some t', then for every i < n an
    // i-step result exists and lies in f's domain (exists_funpow_pred).
    // The postcondition's quantifier is triggered on inrange(i,n), so a caller
    // instantiates it by mentioning inrange(i,n), as the recursive step does.
    lemma funpow_pred_forall<T(!new)>(n:nat, f: T ~> T,t: T, t': T)
    requires funpow_pred(n,f,t,t')
    ensures forall i {:trigger inrange(i,n)} :: 0 <= i < n ==> exists_funpow_pred(i,f,t)
    {
      if n == 0
      {
        // Nothing to prove: no i satisfies 0 <= i < 0.
      }
      else
      {
        forall i | 0 <= i < n
        ensures exists_funpow_pred(i,f,t)
        {
          if i == n - 1
          {
            // funpow_pred(n,...) with n > 0 directly yields an (n-1)-step
            // result that f accepts.
            // NOTE(review): this local t' shadows the parameter t'.
            var t' :| funpow_pred(n-1,f,t,t') && f.requires(t');
            assert funpow_pred(i,f,t,t') && f.requires(t');
          }
          else
          {
            // i < n-1: recurse on the (n-1)-step result, then mention
            // inrange(i,n-1) so the recursive postcondition fires for i.
            var t' :| funpow_pred(n-1,f,t,t') && f.requires(t');
            funpow_pred_forall(n-1,f,t,t');
            assert inrange(i,n-1);
            assert exists t'' :: funpow_pred(i,f,t,t'') && f.requires(t'');
          }
        }
      }
    }
    
    // Validate
    // Partial sample function: defined only for n < 2, where it yields n + 1.
    function foo(n:int):int
    requires n < 2
    {
      1 + n
    }
    
    // Exercises funpow with the partial function foo starting from 0:
    // 0, 1 and 2 applications are provable, but foo.requires(2) is false,
    // so f cannot be applied a third time.
    method Validator()
    {
      assert funpow(0,foo,0) == 0;
      // funpow_req_lem relates funpow.requires(1,foo,0) to facts about
      // funpow(0,foo,0); it is called before funpow(1,foo,0) is used.
      funpow_req_lem(1,foo,0);
      assert funpow(1,foo,0) == 1;
      funpow_req_lem(2,foo,0);
      assert funpow(2,foo,0) == 2;
      assert !foo.requires(2);
    }
    

    P.S. Making the function total, as suggested, would make life much easier.

    // Optional value used as the return type of the total funpow:
    // Some(t) carries a value, None carries nothing.
    datatype option<T> =
        Some(t: T)
      | None
    
    // Total variant: f reports failure by returning None, and funpow applies
    // f up to n times, propagating the first None it encounters.
    function funpow<T(!new)>(n:nat,f:T ~> option<T>,t:T): option<T>
    reads f.reads
    requires forall x :: f.requires(x) // make it total
    {
      if n != 0 then
        (var prev := funpow(n-1, f, t);
         if prev.None? then None else f(prev.t))
      else Some(t)
    }