
How to deal with really large terms generated by Program Fixpoint in Coq?


I'm attempting to define, and prove correct, a Coq function that efficiently diffs two sorted lists. Its recursive calls do not always shrink the same argument (sometimes the first list gets smaller, sometimes the second), so Fixpoint won't accept it, and I'm trying Program Fixpoint instead.
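To illustrate why a plain Fixpoint is rejected, here is a hypothetical toy function, merge_naive (not part of the original code), whose recursive calls have the same shape as make_diff's: one branch shrinks the first list, the other shrinks the second. The guard checker needs one argument that is a strict subterm in every recursive call, and neither l1 nor l2 qualifies, so the Fail command succeeds because the definition is refused.

```coq
Require Import Coq.Lists.List.
Import ListNotations.

(* Hypothetical toy example, not from the post: merging two sorted
   lists of nat with the same call pattern as make_diff. *)
Fail Fixpoint merge_naive (l1 l2 : list nat) : list nat :=
  match l1, l2 with
  | [], _ => l2
  | _, [] => l1
  | h1 :: t1, h2 :: t2 =>
    if Nat.leb h1 h2
    then h1 :: merge_naive t1 l2   (* only l1 shrinks here *)
    else h2 :: merge_naive l1 t2   (* only l2 shrinks here *)
  end.
```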

When I try to prove a property of the function using the tactic simpl or program_simpl, Coq spends minutes computing and then produces a giant term, hundreds of lines long. Am I using Program Fixpoint the wrong way, or are there other tactics I should use instead of simplification when reasoning about it?

I also wondered whether it's good practice to include the properties required for correctness as parameters like this, or whether it would be better to have a separate wrapper function that takes the correctness properties as parameters, so that this function takes only the two lists to be diffed.

Note that I did try defining a simpler version of make_diff, which only took l1 and l2 as parameters and fixed the type A and relation R, but this still produced a gigantic term when the program_simpl or simpl tactics were applied.

Edit: my imports are (although not all of them may be needed here):

Require Import Coq.Sorting.Sorted.
Require Import Coq.Lists.List.
Require Import Coq.Relations.Relation_Definitions.
Require Import Recdef.
Require Import Coq.Program.Wf.
Require Import Coq.Program.Tactics.

The code:

Definition is_decidable (A : Type) (R : relation A) := forall x y, {R x y} + {~(R x y)}.
Definition eq_decidable (A : Type) := forall (x y : A), { x = y } + { ~ (x = y) }.

Inductive diff (X: Type) : Type :=
  | add : X -> diff X
  | remove : X -> diff X 
  | update : X -> X -> diff X.

Program Fixpoint make_diff (A : Type) 
    (R : relation A)
    (dec : is_decidable A R)
    (eq_dec : eq_decidable A)
    (trans : transitive A R) 
    (lt_neq : (forall x y, R x y -> x <> y))
    (l1 l2 : list A)
     {measure (length l1 + length l2) } : list (diff A) :=
  match l1, l2 with
  | nil, nil => nil
  | nil, (new_h::new_t) => (add A new_h) :: (make_diff A R dec eq_dec trans lt_neq nil new_t)
  | (old_h::old_t), nil => (remove A old_h) :: (make_diff A R dec eq_dec trans lt_neq old_t nil)
  | (old_h::old_t) as old_l, (new_h::new_t) as new_l => 
    if dec old_h new_h 
      then (remove A old_h) :: make_diff A R dec eq_dec trans lt_neq old_t new_l
      else if eq_dec old_h new_h 
        then (update A old_h new_h) :: make_diff A R dec  eq_dec trans lt_neq old_t new_t
        else  (add A new_h) :: make_diff A R dec eq_dec trans lt_neq old_l new_t 
  end.
Next Obligation.
Proof.
  simpl.
  generalize dependent (length new_t).
  generalize dependent (length old_t).
  auto with arith.
Defined.
Next Obligation.
Proof.
  simpl.
  generalize dependent (length new_t).
  generalize dependent (length old_t).
  auto with arith.
Defined.

Solution

  • In this particular case we can get rid of Program Fixpoint and use a plain Fixpoint. Every recursive call of make_diff is on either the tail of the first list or the tail of the second list, so we can nest two fixed-point functions: the outer one recurses structurally on l1, and the inner one (make_diff2) keeps l1 fixed and recurses structurally on l2. (I have used the Section mechanism here to avoid passing the same arguments around in every call.)

    Require Import Coq.Lists.List.
    Import ListNotations.
    Require Import Coq.Relations.Relations.
    (* is_decidable, eq_decidable and diff are the definitions from the question. *)
    
    Section Make_diff.
    
    Variable A : Type.
    Variable R : relation A.
    Variable dec : is_decidable A R.
    Variable eq_dec : eq_decidable A.
    Variable trans : transitive A R.
    Variable lt_neq : forall x y, R x y -> x <> y.
    
    (* Outer recursion: structural on l1. *)
    Fixpoint make_diff (l1 l2 : list A) {struct l1} : list (diff A) :=
      (* Inner recursion: l1 stays fixed, structural on l2. *)
      let fix make_diff2 l2 :=
      match l1, l2 with
      | nil, nil => nil
      | nil, new_h::new_t => (add A new_h) :: make_diff2 new_t
      | old_h::old_t, nil => (remove A old_h) :: make_diff old_t nil
      | old_h::old_t, new_h::new_t =>
        (* old_h comes before new_h under R: it was removed *)
        if dec old_h new_h 
        then (remove A old_h) :: make_diff old_t l2  (* l2 = new_h::new_t here *)
        else if eq_dec old_h new_h 
             then (update A old_h new_h) :: make_diff old_t new_t
             else (add A new_h) :: make_diff2 new_t
      end
      in make_diff2 l2.
    
    End Make_diff.
    

    Note that when the section is closed, only the variables that make_diff actually uses become parameters of it; trans and lt_neq are unused, so they do not appear in its signature. Here is a naive test:

    (* make the first two arguments (A and R) implicit *)
    Arguments make_diff [A R] _ _ _ _.
    
    Require Import Coq.Arith.Arith.
    
    Compute make_diff lt_dec Nat.eq_dec [1;2;3] [4;5;6].
    (* = [remove nat 1; remove nat 2; remove nat 3; add nat 4; add nat 5; add nat 6] 
          : list (diff nat) *)