
How to avoid "Cannot guess decreasing argument of fix." in Coq


(* Define a function sum_digits that takes a natural number and a base b and
   returns the sum of the digits of that number written in base b *)


Definition sum_digits (digits : nat) (base : nat) : nat :=
  let fix aux (sum : nat) (x_value : nat) : nat :=
    if leb x_value 0 then
      sum
    else
      aux (add sum (modulo x_value base)) (div x_value base) in
  aux 0 digits.

Error: Cannot guess decreasing argument of fix.

I defined a function according to the description above, but it seems that Coq cannot deal with this kind of recursion. How can I fix this problem?
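For context, Coq only accepts a fix when some argument shrinks structurally, i.e. every recursive call receives a subterm obtained by pattern matching (such as n from S n). div x_value base is numerically smaller than x_value (for x_value > 0 and 1 < base), but it is not syntactically a subterm of it, hence the error. One common workaround, which is a different technique from the Program approach in the solution below, is to add an extra "fuel" argument that does decrease structurally. A minimal sketch, assuming 1 < base; sum_digits_fuel and sum_digits_bounded are hypothetical names:

```coq
Require Import PeanoNat.

(* Hypothetical helper: recursion is on [fuel], which shrinks
   structurally (S fuel' -> fuel'), so the guard checker accepts it. *)
Fixpoint sum_digits_fuel (fuel digits base : nat) : nat :=
  match fuel with
  | O => 0  (* out of fuel: stop *)
  | S fuel' =>
      match digits with
      | O => 0  (* no digits left *)
      | S _ =>
          Nat.modulo digits base
          + sum_digits_fuel fuel' (Nat.div digits base) base
      end
  end.

(* Assumes 1 < base: then a number n > 0 has at most n digits,
   so n itself is enough fuel. *)
Definition sum_digits_bounded (digits base : nat) : nat :=
  sum_digits_fuel digits digits base.

Compute sum_digits_bounded 758 10.  (* = 20 : nat *)
Compute sum_digits_bounded 758 2.   (* = 7 : nat *)
```

The trade-off is that anything you prove about sum_digits_bounded also has to argue that the fuel is sufficient, whereas Program moves the termination argument into a proof obligation.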


Solution

  • You can use the Program library. It can sometimes be a little tricky to use, but in this case it works fine. You state your program, and Coq generates a set of proof obligations (here: that every recursive call is made on a smaller argument). Next Obligation. opens the next obligation for you to solve.

    Here we write {measure digits}, which requires digits to strictly decrease (with respect to <) at every recursive call. This guarantees that the execution eventually terminates.

    Require Import Program Nat.
    
    Program Fixpoint sum_digits (digits : nat) (base : nat) (sum : nat) {measure digits} :=
    if dec (eqb digits 0) then sum 
    else sum_digits (digits/base) base (sum + modulo digits base).
    
    Next Obligation.
    

    Note that I write dec (eqb digits 0) instead of a plain boolean test, so that in the second branch we still know that (digits =? 0) = false, i.e. that digits <> 0.
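    To see what dec buys us: in the Program library, dec is a notation for sumbool_of_bool, which turns a boolean b into a value of type {b = true} + {b = false}. Branching on it hands each branch a proof of the corresponding equation, which a plain if on the boolean would forget. A minimal sketch (the lemma name dec_records_outcome is hypothetical):

```coq
Require Import Program.

(* In Program, dec is a notation for sumbool_of_bool: it turns a boolean
   test into a choice that also carries the outcome as a proof. *)
Check (dec : forall b : bool, {b = true} + {b = false}).

(* Hypothetical lemma: each branch of [if dec ...] receives the
   equation for its outcome as the hypothesis H. *)
Lemma dec_records_outcome (n : nat) :
  if dec (Nat.eqb n 0) then Nat.eqb n 0 = true else Nat.eqb n 0 = false.
Proof.
  destruct (dec (Nat.eqb n 0)) as [H | H].
  - exact H.  (* then-branch: H : Nat.eqb n 0 = true *)
  - exact H.  (* else-branch: H : Nat.eqb n 0 = false *)
Qed.
```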

    When we do Next Obligation. we get the goal

    digits / base < digits
    

    Let's search for a proof of this.

    Search (?a / ?b < ?a).
    
    ==>  PeanoNat.Nat.div_lt: forall a b : nat, 0 < a -> 1 < b -> a / b < a
    

    Ouch. This is something that we can't prove unless we also know that 1 < base: for base = 1 we have digits / 1 = digits, so the goal is false. Let's add an argument Hbase: 1 < base.
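    A quick check of both claims, using Nat.div directly (the Example names are illustrative): dividing by 1 leaves the argument unchanged, while for 1 < base PeanoNat.Nat.div_lt applies once its two premises are discharged, here by lia.

```coq
Require Import PeanoNat Lia.

(* With base = 1 the argument does not shrink, so the obligation
   digits / base < digits would be false. *)
Example base_one_no_progress : Nat.div 758 1 = 758.
Proof. reflexivity. Qed.

(* With 1 < base, Nat.div_lt closes the goal; lia discharges the
   premises 0 < 758 and 1 < 10. *)
Example base_ten_progress : Nat.div 758 10 < 758.
Proof. apply Nat.div_lt; lia. Qed.
```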

    With that change, apply div_lt. The first remaining premise, 0 < digits, should follow from the hypothesis (digits =? 0) = false. Let's search for something that can help us turn that hypothesis into digits <> 0.

    Search ((?a =? ?b) = false).
    
    ==> PeanoNat.Nat.eqb_neq: forall x y : nat, (x =? y) = false <-> x <> y
    

    Let's apply that, and we're essentially done.
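    In isolation, the reasoning for the obligation looks like the sketch below. eqb_false_pos and div_shrinks are hypothetical helper lemmas, not part of the standard library; div_shrinks is the obligation's goal stated with the two hypotheses we have in context.

```coq
Require Import PeanoNat Lia.

(* Hypothetical helper: from the hypothesis that dec gives us
   to the premise 0 < n of Nat.div_lt. *)
Lemma eqb_false_pos (n : nat) : Nat.eqb n 0 = false -> 0 < n.
Proof.
  intro H.
  apply Nat.eqb_neq in H.  (* now H : n <> 0 *)
  lia.
Qed.

(* Hypothetical helper: the whole obligation, given 1 < base. *)
Lemma div_shrinks (n base : nat) :
  1 < base -> Nat.eqb n 0 = false -> Nat.div n base < n.
Proof.
  intros Hbase H.
  apply Nat.div_lt.
  - apply eqb_false_pos. exact H.  (* 0 < n *)
  - exact Hbase.                   (* 1 < base *)
Qed.
```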

    Here is the complete session.

    Require Import Program Nat.
    
    Program Fixpoint sum_digits (digits : nat) (base :nat) (Hbase: 1 < base) (sum : nat) {measure digits} :=
    if dec (eqb digits 0) then sum else sum_digits (digits/base) base Hbase (sum + modulo digits base).
    
    Next Obligation.
      apply PeanoNat.Nat.div_lt.
      apply PeanoNat.Nat.eqb_neq in e.
      destruct digits. exfalso; congruence. apply PeanoNat.Nat.lt_0_succ.
      apply Hbase.
    Defined.
    

    Try it with

    Compute sum_digits 758 2 (le_n 2) 0.  (* le_n 2 : 2 <= 2, which is what 1 < 2 unfolds to *)
    
     = 7
     : nat
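    As a sanity check on the value 7: the function adds up the base-2 digits of 758, and 758 has seven 1 bits.

```latex
\begin{aligned}
758 &= 2^9 + 2^7 + 2^6 + 2^5 + 2^4 + 2^2 + 2^1 = 1011110110_2,\\
\text{digit sum} &= 1+0+1+1+1+1+0+1+1+0 = 7.
\end{aligned}
```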
    

    ADDENDUM:

    When you build terms with Program, you end up with a godawful term that is very annoying to reason about. Therefore you need an unfolding lemma for your function, i.e. an equation stating that sum_digits equals its body, which you can rewrite with.

    It is usually very hard to see what needs to be done when you are proving this lemma, but mostly you have to do case splits that break the match statements into smaller cases.

    Here is an example of the lemma that you would need for this function. It may be useful to look at the proof if you need to use Program with some other function.

    Lemma sum_digits_eq digits base (Hbase: 1 < base) sum:
    sum_digits digits base Hbase sum = if dec (eqb digits 0) then sum else sum_digits (digits/base) base Hbase (sum + modulo digits base).
    Proof.
    unfold sum_digits, sum_digits_func.
    rewrite fix_sub_eq.
    - fold sum_digits_func; simpl.
      now destruct dec.
    - intros.
      destruct dec; simpl.
      + reflexivity.
      + f_equal.
        f_equal.
        unfold sum_digits_func_obligation_1.
        f_equal.
        destruct x; simpl.
        destruct x; simpl.
        reflexivity.
        reflexivity.
    Qed.
    

    With this lemma it is easier to prove other properties of the function, e.g.

    Require Import Lia.
    
    Lemma sum_digits_acc digits base (Hbase: 1 < base) sum:
    sum_digits digits base Hbase sum = sum + sum_digits digits base Hbase 0.
    Proof.
      revert sum.
      induction digits using Wf_nat.lt_wf_ind.
      rename H into IH; intros.
    
      rewrite sum_digits_eq.
      apply eq_sym.
      rewrite sum_digits_eq.
    
      destruct digits.
      - simpl. lia.
      - destruct dec.
        + lia.
        + rewrite IH. 2: apply PeanoNat.Nat.div_lt; lia.
          apply eq_sym. 
          rewrite IH. 2: apply PeanoNat.Nat.div_lt; lia.
          lia.
    Qed.