I was writing a loop-based square root estimation method in Dafny. This is what I have:
```dafny
method sqrt(val :int) returns (root:int)
  requires val >= 0
  ensures root * root >= val && (root - 1) * (root - 1) < val
{
  root := 0;
  var est := val;
  while (est > 0)
    invariant root * root >= val - est
    invariant (root-1) * (root-1) < val
    decreases est
  {
    root := root + 1;
    est := est - (2 * root - 1);
  }
}
```
Dafny is unable to verify this program because of the loop invariant. I can kind of see why: I am assuming it's because root can be 0, and then (0-1) * (0-1) < val could be false if val was 0. But I can't see a solution, as I am new to Dafny. When verifying the code above, I get these errors:
```
src\dafnypractice.dfy(9,34): Error: this loop invariant could not be proved on entry
Related message: loop invariant violation
9 | invariant (root-1) * (root-1) < val

src\dafnypractice.dfy(9,34): Error: this invariant could not be proved to be maintained by the loop
Related message: loop invariant violation
9 | invariant (root-1) * (root-1) < val
```
Any help is appreciated.
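Yes, you identified a case where the invariant isn't true. On entry `root == 0`, so `(root-1) * (root-1)` is `1`, and `1 < val` fails for `val == 0` and also for `val == 1`. You could require `val > 1`, but then you would need to handle the remaining cases separately. You can make it work with a small change: relax the comparison to less than or equal to `val` (guarded by `val != 0`). The second error, about maintenance, has a different cause. `root * root >= val - est` only bounds `root * root` from below, so Dafny cannot conclude that the old `root * root` was below `val`, which is exactly what the new `(root-1) * (root-1) < val` requires. That is why the version below pins `est` down exactly with `est == val-SumOfNOddNumbers(root)`. Verifying loop invariants is all about induction: each invariant must hold on entry and be preserved by one iteration of the body. Inside and after the loop, Dafny knows nothing about the variables the loop modifies beyond what the invariants (and the loop guard) state.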
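The arithmetic behind both errors can be written out as a short derivation. Here r is a name introduced only for this sketch: the value of `root` at the top of an iteration. The loop subtracts the odd numbers 1, 3, 5 and so on from `est`, and their sum is a perfect square:

$$\sum_{k=1}^{r} (2k-1) = r^2 \qquad\Longrightarrow\qquad est = val - r^2$$

On entry r = 0, so the invariant reads

$$(0-1)^2 = 1 < val,$$

which is false for val = 0 and val = 1. For maintenance, the loop guard gives

$$est > 0 \;\Longrightarrow\; r^2 = val - est < val \;\Longrightarrow\; \big((r+1)-1\big)^2 < val,$$

but this step needs the exact equation est = val − r², which the invariant `root * root >= val - est` alone does not provide.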
```dafny
function toOdd(n: nat): nat
  requires n > 0
{
  2*n-1
}

function SumOfNOddNumbers(n: nat): nat {
  if n == 0 then 0 else toOdd(n)+SumOfNOddNumbers(n-1)
}

lemma SumOddIsSquared(n: nat)
  ensures SumOfNOddNumbers(n) == n*n
{}

method sqrt(val :nat) returns (root:nat)
  ensures val == 0 ==> root == 0
  ensures val != 0 ==> (root - 1) * (root - 1) <= val
  ensures val != 0 ==> root * root >= val
{
  root := 0;
  var est: int := val;
  while (est > 0)
    invariant val == 0 ==> root == 0
    // invariant est == val ==> root == 0
    invariant est == val-SumOfNOddNumbers(root)
    invariant root * root >= val - est
    invariant val != 0 ==> (root-1) * (root-1) <= val
    decreases est
  {
    root := root + 1;
    ghost var oldEst := est;
    est := est - (2 * root - 1);
    assert val != 0 ==> (root-1) * (root-1) <= val by {
      // assert oldEst == val - SumOfNOddNumbers(root-1);
      assert oldEst > 0;
      SumOddIsSquared(root-1);
    }
  }
}
```
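The empty body of `SumOddIsSquared` works because Dafny attempts induction on `n` automatically. Below is a sketch of the same induction written out by hand. The lemma name `SumOddIsSquaredExplicit` is hypothetical, and `toOdd` and `SumOfNOddNumbers` are copied from the answer. The recursive call supplies the induction hypothesis for `n - 1`, and a `calc` block walks through the inductive step:

```dafny
// Definitions copied from the answer above.
function toOdd(n: nat): nat
  requires n > 0
{
  2 * n - 1
}

function SumOfNOddNumbers(n: nat): nat
{
  if n == 0 then 0 else toOdd(n) + SumOfNOddNumbers(n - 1)
}

// Hypothetical name: SumOddIsSquared with the induction spelled out.
lemma SumOddIsSquaredExplicit(n: nat)
  ensures SumOfNOddNumbers(n) == n * n
{
  // Base case n == 0: both sides are 0, nothing to do.
  if n > 0 {
    // Induction hypothesis: SumOfNOddNumbers(n - 1) == (n - 1) * (n - 1)
    SumOddIsSquaredExplicit(n - 1);
    calc {
      SumOfNOddNumbers(n);
    ==  // unfold the definition, since n != 0
      toOdd(n) + SumOfNOddNumbers(n - 1);
    ==  // definition of toOdd plus the induction hypothesis
      (2 * n - 1) + (n - 1) * (n - 1);
    ==  // plain arithmetic: (n - 1)^2 + 2n - 1 == n^2
      n * n;
    }
  }
}
```

Spelling the step out like this mainly pays off when automatic induction gives up, for example on lemmas with more complicated statements. For this particular lemma, the answer's empty body is already enough.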
A more complete version, with the strict postcondition `(root - 1) * (root - 1) < val` for every `val != 0`:
```dafny
module SOSqrt {
  function toOdd(n: nat): nat
    requires n > 0
  {
    2*n-1
  }

  function SumOfNOddNumbers(n: nat): nat {
    if n == 0 then 0 else toOdd(n)+SumOfNOddNumbers(n-1)
  }

  lemma SumOddIsSquared(n: nat)
    ensures SumOfNOddNumbers(n) == n*n
  {}

  lemma SquareOfGreatNLarger(root: nat, n: nat)
    requires n > root
    ensures SumOfNOddNumbers(n) > SumOfNOddNumbers(root)
  {}

  lemma lessSquared(a: nat, b: nat)
    requires a <= b
    ensures a*a <= b*b
  {
    if a == b {
      assert a*a <= b*b;
    } else {
      var diff := b-a;
      assert diff > 0;
      calc {
        b*b;
        (a+diff)*(a+diff);
        a*a +2*a*diff + diff * diff;
      }
      assert a*a <= b*b;
    }
  }

  method sqrt(val :nat) returns (root:nat)
    ensures val == 0 ==> root == 0
    ensures val == 1 ==> root == 1
    ensures val != 0 ==> root * root >= val
    ensures val != 0 ==> (root - 1) * (root - 1) < val
  {
    root := 0;
    var est: int := val;
    while (est > 0)
      invariant val == 0 ==> root == 0
      invariant est == val - SumOfNOddNumbers(root)
      invariant root * root >= val - est
      invariant val > 1 ==> (root-1) * (root-1) < val
      invariant val == 1 ==> (root-1) * (root-1) <= val
      invariant est <= 0 ==> forall n :nat :: n > root ==> est > val - SumOfNOddNumbers(n)
      invariant est <= 0 ==> forall n: nat :: n < root ==> val - SumOfNOddNumbers(n) > 0
      decreases est
    {
      root := root + 1;
      ghost var oldEst := est;
      est := est - (2 * root - 1);
      assert val > 1 ==> (root-1) * (root-1) < val by {
        assert oldEst > 0;
        SumOddIsSquared(root-1);
      }
      assert est <= 0 ==> forall n :nat :: n > root ==> est > val-SumOfNOddNumbers(n) by {
        forall n | n > root
          ensures est > val-SumOfNOddNumbers(n)
        {
          assert n >= root +1;
          SquareOfGreatNLarger(root, n);
        }
      }
      assert est <= 0 ==> forall n :nat :: n < root ==> val-SumOfNOddNumbers(n) > 0 by {
        if est <= 0 {
          assert val >= 1;
          forall n : nat | n < root
            ensures val-SumOfNOddNumbers(n) > 0
          {
            SumOddIsSquared(n);
            lessSquared(n, root-1);
            assert n*n <= (root-1) * (root-1);
          }
        }
      }
    }
  }
}
```
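For comparison, here is a minimal sketch that stays closer to the question's original `sqrt`. It rests on one assumption: it adds `requires val > 0`, because for `val == 0` no integer `root` can satisfy `(root - 1) * (root - 1) < 0`. Instead of going through `SumOfNOddNumbers`, it states the exact relation `est == val - root * root` directly and only claims the `(root - 1)` bound once `root > 0`. Treat it as a starting point to adapt rather than a drop-in replacement:

```dafny
// Sketch: the question's loop with an exact invariant on est.
// Assumption: val > 0, since for val == 0 no integer root satisfies
// (root - 1) * (root - 1) < 0.
method sqrt(val: int) returns (root: int)
  requires val > 0
  ensures root * root >= val && (root - 1) * (root - 1) < val
{
  root := 0;
  var est := val;
  while est > 0
    // est is exactly how far root * root still is below val
    invariant est == val - root * root
    invariant root >= 0
    // lets Dafny rule out root == 0 after the loop (est would equal val > 0)
    invariant root == 0 ==> est == val
    // vacuous on entry (root == 0), so val == 1 needs no special case
    invariant root > 0 ==> (root - 1) * (root - 1) < val
    decreases est
  {
    root := root + 1;
    // subtracting the root-th odd number keeps est == val - root * root
    est := est - (2 * root - 1);
  }
}
```

With this specification the method returns the smallest `root` with `root * root >= val`, that is, a rounded-up square root: `sqrt(10)` gives `4` and `sqrt(9)` gives `3`.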