
Proving enumerate(...) to range(len(...)) equality


I'm trying to prove in Coq that the following two Python loops iterate over the same indices: for i, _ in enumerate(l, s) and for i in range(s, len(l) + s).

I've written recursive definitions of both enumerate and range, and I've also stated and proved some auxiliary theorems to help with the original goal.

However, I got stuck midway. I ended up with the goal [] = [s], which is obviously false. I therefore specialized nil_cons in an attempt to use discriminate, but that didn't help: "No primitive equality found."

I'm sure there must be some other way around this; I just don't see it.

Here's what I have so far:

   Require Import Coq.Lists.List Coq.Bool.Bool Coq.Arith.Minus.

   Import Coq.Lists.List.ListNotations.
   
   Fixpoint _range (a b d : nat) :=
     match d with
       | 0 => []
       | S n => a :: _range (S a) b n
     end.
   
   Definition range a b := _range a b (b - a).
   
   (*
   Fixpoint less n m : bool :=
     match n, m with
       | 0, 0       => false
       | 0, S _     => true
       | S _, 0     => false
       | S n, S m   => less n m
     end.
  
   Fixpoint range (a b : nat) :=
     if less a b then
        match b with
          | 0 => []
          | S n => range a n ++ [n]
        end
     else [].
   *)
   
   Eval compute in range 0 0.
   Eval compute in range 0 5.
   Eval compute in range 5 0.
   Eval compute in range 5 10.
   
   Theorem range_eq_empty : forall (s : nat),
                            range s s = [].
   Proof.
      induction s. reflexivity. unfold range. rewrite <- minus_n_n. reflexivity.
   Qed.
   
   Fixpoint enumerate (T : Type) (l : list T) (s : nat) :=
     match l with
       | [] => []
       | h :: t => (s, h) :: enumerate T t (S s) 
     end.
   
   Eval simpl in enumerate nat [0;1;2;3;4] 0.
   Eval simpl in enumerate nat [5;6;7;8;9] 5.
   
   Theorem enum_prop_fwd : forall (T : Type) (l : list T) (s : nat),
                           enumerate T l s = [] -> l = [].
   Proof.
      intros. induction l. reflexivity. pose proof nil_cons. symmetry. specialize (H0 T a l). discriminate.
   Qed.
   
   Theorem enum_prop_bwd : forall (T : Type) (l : list T) (s : nat),
                           l = [] -> enumerate T l s = [].
   Proof.
      intros. rewrite H. reflexivity.
   Qed.
   
   Theorem enum_map_prop_fwd : forall (T : Type) (l : list T) (s : nat),
                               enumerate T l s = [] -> map fst (enumerate T l s) = [].
   Proof.
      intros. rewrite H. reflexivity.
   Qed.
   
   Theorem enum_map_prop_bwd : forall (T : Type) (l : list T) (s : nat),
                               map fst (enumerate T l s) = [] -> enumerate T l s = [].
   Proof.
      intros. pose proof (enum_prop_bwd T l s). induction l. reflexivity.
      apply H0. symmetry. pose proof nil_cons. specialize (H1 T a l). discriminate.
   Qed.

   Theorem enum_revert : forall (T : Type) (l : list T) (a : T) (s : nat),
                         s :: map fst (enumerate T l (S s)) = map fst (enumerate T (a :: l) s).
   Proof.
      intros. induction s; reflexivity.
   Qed.
   
   (* for i, _ in enumerate(l, s) = for i in range(s, len(l) + s) *)
   Theorem abc : forall (T : Type) (l : list T) (s : nat),
                 map fst (enumerate T l s) = range s (length l + s).
   Proof.
      intros. induction (length l). induction l. induction s. rewrite range_eq_empty. reflexivity.
      rewrite range_eq_empty. reflexivity. rewrite range_eq_empty. rewrite <- enum_revert. symmetry.
      rewrite range_eq_empty in IHl. pose proof (enum_prop_fwd T l s (enum_map_prop_bwd T l s IHl)). 
      rewrite H. compute. pose proof nil_cons. specialize (H0 nat s []). discriminate.

Solution

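  • You can start your proof of abc with an induction on l itself (not on length l). Induction on length l replaces that term with a fresh number and forgets its connection to l, which is how an unprovable goal such as [] = [s] can arise. Because s has not been introduced yet when induction l runs, the induction hypothesis holds for every s and can be used at S s in the inductive step.

    The arithmetic equalities that appear in the proof are solved with lia.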

    
     Require Import Lia. 
    
     Theorem abc : forall (T : Type) (l : list T) (s : nat),
                     map fst (enumerate T l s) = range s (length l + s).
       Proof.
         induction l. 
         - intro s; cbn; now replace (s - s) with 0 by lia. 
         - destruct s as [| s].
           + cbn; rewrite IHl; unfold range; do 2 f_equal; lia.  
           + cbn; rewrite IHl; unfold range; cbn.
             replace (length l + S s - s) with (S (length l)) by lia.
             cbn; do 2 f_equal; lia. 
       Qed.
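
To see concretely why induction on length l gets stuck, here is a small sketch. It assumes the definitions of _range, range and enumerate from the question, restated (with explicit return types) so the snippet stands alone; the name base_case_is_false is made up for illustration. After induction (length l), the base case asks for map fst (enumerate T l s) = range s (0 + s) for an arbitrary l, and the instance l = [0], s = 0 shows that this statement is false. The proof also shows discriminate applied to a hypothesis that equates two different constructors, which is the situation it is designed for, rather than to a goal of the form [] = [s].

```coq
Require Import List.
Import ListNotations.

(* Definitions restated from the question. *)
Fixpoint _range (a b d : nat) : list nat :=
  match d with
  | 0 => []
  | S n => a :: _range (S a) b n
  end.

Definition range (a b : nat) : list nat := _range a b (b - a).

Fixpoint enumerate (T : Type) (l : list T) (s : nat) : list (nat * T) :=
  match l with
  | [] => []
  | h :: t => (s, h) :: enumerate T t (S s)
  end.

(* Hypothetical name. The statement is the base case left by
   `induction (length l)`, instantiated at l = [0] and s = 0. *)
Example base_case_is_false :
  map fst (enumerate nat [0] 0) <> range 0 (0 + 0).
Proof.
  intro H.          (* assume the equality holds *)
  compute in H.     (* both sides evaluate, giving H : [0] = [] *)
  discriminate H.   (* cons and nil are distinct constructors *)
Qed.
```

Since the base case is already false for some l, no sequence of tactics can finish a proof that starts this way; induction on l keeps the list and its length tied together and avoids the problem.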