As part of a program that I'm writing, I need to compare two values in the form a + sqrt(b), where a and b are unsigned integers. As this is part of a tight loop, I'd like this comparison to run as fast as possible. (If it matters, I'm running the code on x86-64 machines, and the unsigned integers are no larger than 10^6. Also, I know for a fact that a1<a2.)
As a stand-alone function, this is what I'm trying to optimize. My numbers are small enough integers that double (or even float) can exactly represent them, but rounding error in sqrt results must not change the outcome.
// known pre-condition: a1 < a2 in case that helps
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    return a1+sqrt(b1) < a2+sqrt(b2); // computed mathematically exactly
}
Test case: is_smaller(900000, 1000000, 900001, 998002) should return true, but as shown in comments by @wim, computing it with sqrtf() would return false. So would using (int)sqrt() to truncate back to integer. Here a1+sqrt(b1) = 901000 and a2+sqrt(b2) = 901000.00050050037512481206. The nearest float to that is exactly 901000.
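The failure mode described for this test case can be reproduced with a minimal sketch like the one below. The names is_smaller_float and is_smaller_trunc are hypothetical and exist only for illustration. The sketch assumes IEEE-754 single and double precision with no excess intermediate precision, which is the usual situation on x86-64 with SSE.

```cpp
#include <cmath>

// Hypothetical name: naive single-precision version of the comparison.
// Each side is stored in a float, so it is rounded to single precision.
bool is_smaller_float(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    float lhs = static_cast<float>(a1) + std::sqrt(static_cast<float>(b1));
    float rhs = static_cast<float>(a2) + std::sqrt(static_cast<float>(b2));
    return lhs < rhs;
}

// Hypothetical name: double-precision sqrt truncated back to an integer.
bool is_smaller_trunc(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    unsigned r1 = static_cast<unsigned>(std::sqrt(static_cast<double>(b1)));
    unsigned r2 = static_cast<unsigned>(std::sqrt(static_cast<double>(b2)));
    return a1 + r1 < a2 + r2;
}
```

Under those assumptions, both functions return false for (900000, 1000000, 900001, 998002), because the fractional part that separates the two sides is lost to rounding or truncation.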
As the sqrt() function is generally quite expensive even on modern x86-64 when fully inlined as a sqrtsd instruction, I'm trying to avoid calling sqrt() as far as possible.
Removing sqrt by squaring potentially also avoids any danger of rounding errors by making all computation exact.
If instead the function was something like this ...
bool is_smaller(unsigned a1, unsigned b1, unsigned x) {
    return a1+sqrt(b1) < x;
}
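... then I could just do return x>=a1 && static_cast<uint64_t>(x-a1)*(x-a1)>b1; (the check has to be x>=a1 rather than x-a1>=0, because x-a1 is unsigned and therefore never negative).
But now since there are two sqrt(...) terms, I cannot do the same algebraic manipulation.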
I could square the values twice, by using this formula:
a1 + sqrt(b1) = a2 + sqrt(b2)
<==> a1 - a2 = sqrt(b2) - sqrt(b1)
<==> (a1 - a2) * (a1 - a2) = b1 + b2 - 2 * sqrt(b1) * sqrt(b2)
<==> (a1 - a2) * (a1 - a2) = b1 + b2 - 2 * sqrt(b1 * b2)
<==> (a1 - a2) * (a1 - a2) - (b1 + b2) = - 2 * sqrt(b1 * b2)
<==> ((b1 + b2) - (a1 - a2) * (a1 - a2)) / 2 = sqrt(b1 * b2)
<==> ((b1 + b2) - (a1 - a2) * (a1 - a2)) * ((b1 + b2) - (a1 - a2) * (a1 - a2)) / 4 = b1 * b2
Unsigned division by 4 is cheap because it is just a bitshift, but since I square the numbers twice I will need to use 128-bit integers and I will need to introduce a few >=0 checks (because I'm comparing inequality instead of equality).
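For concreteness, the twice-squared approach with its sign checks could look roughly like the sketch below. This is only one way to arrange it, not a tuned implementation. The name is_smaller_two_squarings is hypothetical, the code relies on the precondition a1 < a2, and unsigned __int128 is a GCC/Clang extension rather than standard C++. It multiplies the right-hand side by 4 instead of dividing the left-hand side, so no division is needed at all.

```cpp
#include <cassert>

// Hypothetical name: exact test of a1 + sqrt(b1) < a2 + sqrt(b2) via two squarings.
// Assumes a1 < a2. The 128-bit type is wide enough for any 32-bit unsigned inputs.
bool is_smaller_two_squarings(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    assert(a1 < a2);
    using u128 = unsigned __int128;   // GCC/Clang extension

    // Rearranged: sqrt(b1) - sqrt(b2) < d, with d = a2 - a1 > 0.
    if (b1 <= b2) return true;        // left side <= 0 < d

    u128 d = a2 - a1;
    u128 d2 = d * d;
    u128 sum = static_cast<u128>(b1) + b2;

    // First squaring: b1 + b2 - d^2 < 2*sqrt(b1*b2).
    if (sum < d2) return true;        // left side negative, right side >= 0

    // Second squaring (both sides non-negative): (b1 + b2 - d^2)^2 < 4*b1*b2.
    u128 s = sum - d2;
    return s * s < 4 * static_cast<u128>(b1) * b2;
}
```

For the test case above, is_smaller_two_squarings(900000, 1000000, 900001, 998002) compares 1998001^2 against 4 * 1000000 * 998002 and returns true.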
It feels like there might be a way to do this faster, by applying better algebra to this problem. Is there a way to do this faster?
Here's a version without sqrt, though I'm not sure whether it is faster than a version which has only one sqrt (it may depend on the distribution of values).
Here's the math (how to remove both sqrts):
ad = a2-a1
bd = b2-b1
a1+sqrt(b1) < a2+sqrt(b2) // subtract a1
sqrt(b1) < ad+sqrt(b2) // square it
b1 < ad^2+2*ad*sqrt(b2)+b2 // arrange
ad^2+bd > -2*ad*sqrt(b2)
Here, the right side is never positive, because ad > 0 (we know a1 < a2). If the left side is positive, the inequality holds and we have to return true.
Otherwise both sides are non-positive, so we can square the inequality, which reverses its direction:
ad^4+bd^2+2*bd*ad^2 < 4*ad^2*b2
The key thing to notice here is that if a2>a1+1000, then is_smaller always returns true (because the maximum value of sqrt(b1) is 1000). If a2<=a1+1000, then ad is a small number, so ad^4 will always fit into 64 bits (there is no need for 128-bit arithmetic). Here's the code:
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
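    int ad = a2 - a1;                 // > 0 by the precondition a1 < a2
    if (ad>1000) {
        return true;                  // sqrt(b1) <= 1000 < ad
    }
    int bd = b2 - b1;
    if (ad*ad+bd>0) {
        return true;                  // left side positive, right side <= 0
    }
    int ad2 = ad*ad;
    // squared inequality: (ad^2+bd)^2 < 4*ad^2*b2, evaluated in 64 bits
    return (long long int)ad2*ad2 + (long long int)bd*bd + 2ll*bd*ad2 < 4ll*ad2*b2;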
}
EDIT: As Peter Cordes noticed, the first if is not necessary, as the second if handles it, so the code becomes smaller and faster:
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
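    int ad = a2 - a1;                 // > 0 by the precondition a1 < a2
    int bd = b2 - b1;
    if ((long long int)ad*ad+bd>0) {  // 64-bit: ad*ad can reach 10^12 here
        return true;
    }
    int ad2 = ad*ad;                  // here ad*ad <= -bd <= 10^6, so it fits in int
    return (long long int)ad2*ad2 + (long long int)bd*bd + 2ll*bd*ad2 < 4ll*ad2*b2;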
}
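As a quick sanity check, the final version can be run against a few hand-computed cases, including the test case from the question and inputs where the two sides are exactly equal. The sketch below repeats the function so that it compiles on its own. The Case struct and count_failures helper are hypothetical names used only for this illustration, and the expected values were worked out by hand from a + sqrt(b).

```cpp
// Copy of the final is_smaller, repeated so this sketch is self-contained.
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    int ad = a2 - a1;
    int bd = b2 - b1;
    if ((long long int)ad*ad+bd>0) {
        return true;
    }
    int ad2 = ad*ad;
    return (long long int)ad2*ad2 + (long long int)bd*bd + 2ll*bd*ad2 < 4ll*ad2*b2;
}

// Hypothetical helper type: one input tuple plus the hand-computed answer.
struct Case { unsigned a1, b1, a2, b2; bool expected; };

// Hypothetical helper: returns how many cases disagree with is_smaller.
int count_failures() {
    static const Case cases[] = {
        {900000, 1000000, 900001, 998002, true},  // test case from the question
        {0, 4, 2, 0, false},                      // 0 + 2 == 2 + 0, not smaller
        {0, 1000000, 1000, 0, false},             // 1000 == 1000, not smaller
        {0, 1000000, 1001, 0, true},              // 1000 < 1001
        {0, 9, 1, 3, false},                      // 3 > 1 + sqrt(3)
        {0, 9, 1, 5, true},                       // 3 < 1 + sqrt(5)
    };
    int failures = 0;
    for (const Case& c : cases) {
        if (is_smaller(c.a1, c.b1, c.a2, c.b2) != c.expected) ++failures;
    }
    return failures;
}
```

Calling count_failures() from main and checking that it returns 0 covers both early-return paths and the fully squared comparison. It is not a substitute for an exhaustive or randomized comparison against an independent exact reference.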