Search code examples
dafny

Prove that multiplying and then dividing by the same number is the identity


I'm stuck trying to prove the following lemma:

// The lemma as originally posed. According to the question, Dafny cannot
// verify it with this empty body; the answer below proves the same statement
// under the name MulDivId.
lemma MulDivIsId(a: nat, b: nat)
requires b > 0
ensures a * b / b == a
{
}

I tried to help the verifier by showing that b / b is 1, and I also tried stating that a * b / b equals a * (b / b), but the verifier couldn't prove that equality either. How can I prove this? Are there general steps that usually help when proving lemmas involving integer or natural-number division? (Dafny's `/` on `int` is Euclidean division; for natural numbers it coincides with truncating division.)


Solution

  • This particular lemma is already proved in the Dafny standard library. That proof relies on additional lemmas about multiplication and division, but it should be easy to follow.


    Edit: The code below proves this lemma manually, without relying on the standard library.

    // Quotient/remainder decomposition of k by a positive n:
    // k == n * (k / n) + k % n.
    lemma DivModLemma(k: nat, n: nat)
       requires n > 0
       ensures n * (k / n) + k % n == k
    {
      // Restated for the reader; the verifier's built-in reasoning about
      // `/` and `%` discharges it without further hints.
      assert k == n * (k / n) + k % n;
    }
    
    // Zero-product property with a positive first factor:
    // if a > 0 and a * b == 0, then b must be 0.
    lemma ZeroLemma(a: nat, b: nat)
      requires a > 0
      requires a * b == 0
      ensures b == 0
    {
      // No manual steps are needed; the goal is restated for readability.
      assert b == 0;
    }
    
    // Cancellation of a positive common factor:
    // c * a == c * b with c > 0 implies a == b.
    lemma CancelLemma(a: nat, b: nat, c: nat)
       requires c > 0
       requires c * a == c * b
       ensures a == b
    {
      // Subtract the smaller argument from the larger one, so the difference
      // handed to ZeroLemma (whose second parameter is a nat) is provably
      // non-negative. Passing `a - b` unconditionally would force the verifier
      // to show a >= b before it knows a == b. In either branch,
      // c * difference == c * a - c * b == 0, so ZeroLemma makes the
      // difference 0. In the `else` branch that contradicts a > b, so the
      // branch is infeasible.
      if a <= b {
        assert c * (b - a) == c * b - c * a;
        ZeroLemma(c, b - a);
      } else {
        assert c * (a - b) == c * a - c * b;
        ZeroLemma(c, a - b);
      }
    }
    
    // When b divides a exactly, a is recovered from its quotient:
    // a == (a / b) * b.
    lemma ModZeroLemma(a: nat, b: nat)
      requires b > 0
      requires a % b == 0
      ensures a == (a / b) * b
    {
      // With a zero remainder, the decomposition a == b * (a / b) + a % b
      // collapses to the goal; the verifier needs no further hints.
      assert a == (a / b) * b;
    }
    
    // If a equals some multiple m * b of b, then b divides a exactly.
    // Call sites in this file make the witness apparent: either an explicit
    // `a == m * b`-shaped assertion, or an argument written as `m * b`.
    lemma ExactDivisionLemma(a: nat, b: nat)
      requires b > 0
      requires exists m :: m * b == a
      ensures a % b == 0
    {
      // a == b * (a / b) + a % b
      DivModLemma(a, b);
      assert b * (a / b) + a % b == a;
      // Pick a witness m with m * b == a from the precondition.
      var m :| m * b == a;
    
      assert b * (a / b) + a % b == m * b;
      // Rearranged: (q - m) * b + r == 0, where q = a / b and r = a % b.
      assert ((a / b) - m) * b + a % b == 0;
      assert 0 <= a % b < b;
    
      // Case split on q - m. If q - m <= -1, the left-hand side above is at
      // most -b + r < 0; if q - m >= 1, it is at least b > 0. Both contradict
      // "== 0", so only q == m remains, which forces r == 0.
      if a / b - m == 0 {}
      else if a / b - m < 0 {
        assert (a / b - m) <= -1;
      }
      else {
        assert (a / b - m) >= 1;
      }
    }
    
    // The sum of two multiples of c is itself a multiple of c.
    // NOTE(review): the name presumably means "Remainder"; it is kept as-is
    // because ModLemma calls it by this name.
    lemma ReminderLemma(a: nat, b: nat, c : nat)
      requires c > 0
      requires a % c == 0 && b % c == 0
      ensures (a + b) % c == 0
    {
      // a == (a / c) * c and b == (b / c) * c
      ModZeroLemma(a, c);
      ModZeroLemma(b, c);
      assert a + b == (a / c) * c + (b / c) * c;
      // Factor out c. This exhibits a + b as (a / c + b / c) * c, which is the
      // witness ExactDivisionLemma's existential precondition needs.
      assert a + b == (a / c + b / c) * c;
      ExactDivisionLemma(a + b, c);
    }
    
    // Division by c distributes over a sum when both summands are multiples
    // of c: (a + b) / c == a / c + b / c.
    lemma ModLemma(a: nat, b: nat, c: nat)
      requires c > 0
      requires a % c == 0 && b % c == 0
      ensures (a + b) / c == a / c + b / c
    {
      // Step 1: c * ((a + b) / c) == a + b, since (a + b) % c == 0.
      DivModLemma(a + b, c);
      assert c * ((a + b) / c) + (a + b) % c == (a + b);
      ReminderLemma(a, b, c);
      assert c * ((a + b) / c) == (a + b);
      // Step 2: c * (a / c) == a and c * (b / c) == b (zero remainders).
      DivModLemma(a, c);
      assert c * (a / c) + a % c == a;
      assert c * (a / c) == a;
      DivModLemma(b, c);
      assert c * (b / c) + b % c == b;
      assert c * (b / c) == b;
      // Step 3: both sides are c times something equal to a + b ...
      assert c * (a / c) + c * (b / c) == a + b;
      assert c * (a / c + b / c) == a + b;
      assert c * ((a + b) / c) == c * (a / c + b / c);
      // ... so cancelling the positive factor c gives the postcondition.
      CancelLemma((a + b) / c, a / c + b / c, c);
    }
    
    // The question's MulDivIsId under another name: (a * b) / b == a.
    // Proof by induction on a: peel one copy of b off a * b, split the
    // division with ModLemma, and apply the induction hypothesis to a - 1.
    lemma MulDivId(a: nat, b: nat)
      requires b > 0
      ensures (a * b) / b == a
    {
      // Base case a == 0: no hints are given here.
      if a == 0 {}
      else {
        calc {
           (a * b) / b;
           ((a - 1 + 1) * b ) / b;
           ((a - 1) * b + b ) / b;
           {
             // (a - 1) * b is visibly a multiple of b, so its remainder is 0.
             // b % b is 0 as well, so ModLemma lets the division distribute.
             ExactDivisionLemma((a-1) * b, b);
             ModLemma((a - 1) * b, b, b);
           }
           (a - 1) * b / b + b / b;
           // Induction hypothesis on a - 1 (a > 0 in this branch, so a - 1 < a).
           { MulDivId(a - 1, b); }
           a - 1 + 1;
           a;
        }
      }
    }