Search code examples
Tags: javascript, algorithm, bit-manipulation

Find a solution to a Math.imul() when given a known result


I know the result of a math.imul statement and would like to find one of the solutions that gives the known result. How would I solve for a value of x that works?

Math.imul( x , 2654435761); // the result of this statement is 1447829970

Solution

  • I'm going to put the unit test and benchmark of the different solutions at the end of this post. It currently includes my two functions and @harold's solution.

    My original answer

    Basically, you want to find $x$ such that $(x \cdot 2654435761) \bmod 2^{32} = 1447829970$.

    There's an iterative solution for $(x \cdot y) \bmod 2^n = z$ with known $y$ and $z$. It is based on the fact that any odd number is coprime with any power of two $2^n$. This means there is exactly one solution $x \bmod 2^n$ for any odd $y$ and arbitrary $z$ and $n$.

    The calculation therefore starts with modulo $2^1$: $((x \bmod 2^1) \cdot (y \bmod 2^1)) \bmod 2^1 = z \bmod 2^1$.

    Only two candidates for $x \bmod 2^1$ exist: 0 and 1. The algorithm picks the one that works. It then goes up by one power of two. Modulo $2^{k+1}$, the second possible solution is the previous one increased by $2^k$ (i.e. by $2^1$ on the first step). Again, the valid one is picked, and the process continues until modulo $2^{32}$ is reached.

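    However, if $y$ is even, then $z$ must have at least as many trailing zeros as $y$ has; otherwise, there's no solution. When a solution does exist for an even $y$, it is no longer unique, so different algorithms may return different (equally valid) values of $x$.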

    PS: This works in $O(n)$ time with fixed-width integers. But a variation for arbitrary-width integers (such as BigInt) may explode up to $O(n^3)$, depending on how multiplication is implemented.

    /**
     * Finds an x such that Math.imul(x, y) >>> 0 === z (original version).
     *
     * First skips y's trailing zero bits, where z must be zero as well. Then x
     * is built one bit at a time. Two candidates are kept: they agree on every
     * bit decided so far and differ only in the next one. The candidate that
     * reproduces the low bits of z is the one that survives.
     *
     * @param {number} y - multiplier, treated as a Uint32.
     * @param {number} z - desired product, treated as a Uint32.
     * @returns {number} a Uint32 solution x, or NaN when z has fewer trailing
     *     zeros than y (no solution exists).
     */
    function unmultiply (y, z) {
        y >>>= 0;
        z >>>= 0;
        // width = how many low bits of the product the next check compares.
        let width = 1;
        for (; width <= 32; width++) {
            const bit = 1 << (width - 1);
            if ((y & bit) !== 0) {
                break;
            }
            if ((z & bit) !== 0) {
                // z has a set bit below y's lowest set bit: no solution.
                return NaN;
            }
        }
        let low = 0;   // current candidate
        let high = 1;  // alternative: `low` with the next undecided bit set
        for (let shift = 1; width < 32; width++, shift++) {
            const mask = 0xffffffff >>> (32 - width);
            if ((Math.imul(low, y & mask) & mask) !== (z & mask)) {
                low = high;
            }
            high = low | (1 << shift);
        }
        return (Math.imul(low, y) >>> 0) === z ? low >>> 0 : high >>> 0;
    }
    

    Multiplication-free version

    The day after posting the original version, I realized that the same approach can be implemented without multiplication. It only requires additions, shifts, and comparisons. However, it costs more operations per iteration, so it may run slower, because integer multiplication is already quite cheap on modern hardware. But if the same idea is applied to arbitrary-width integers, it should work in just $O(n^2)$ time, so it is worth considering.

    /**
     * Finds an x such that Math.imul(x, y) >>> 0 === z using only shifts,
     * additions and comparisons (multiplication-free version).
     *
     * One bit of x is decided per step, starting at y's lowest set bit. A
     * running sum holds the product of the x bits decided so far with the low
     * bits of y. That sum is enough to decide each next bit.
     *
     * @param {number} y - multiplier, treated as a Uint32.
     * @param {number} z - desired product, treated as a Uint32.
     * @returns {number} a Uint32 solution x, or NaN when z has fewer trailing
     *     zeros than y (no solution exists).
     */
    function unmultiply (y, z) {
        y >>>= 0;
        z >>>= 0;
        if (y === 0) {
            return z === 0 ? 0 : NaN;
        }
        // Index of y's lowest set bit (clz trick courtesy of @harold).
        let pos = Math.clz32(y & -y) ^ 31;
        // Any set bit of z below that index makes the equation unsolvable.
        if (pos !== 0 && (z << (32 - pos)) !== 0) {
            return NaN;
        }
        let x = 0;
        let decided = 0;                        // mask of x bits fixed so far
        let acc = 0;                            // running partial product, mod 2^32
        let zMask = 0xffffffff >>> (31 - pos);  // low (pos + 1) bits
        for (let step = 0; pos < 32; step++, pos++) {
            if ((y & (1 << pos)) !== 0) {
                acc = (acc + ((x & decided) << pos)) >>> 0;
            }
            if ((acc & zMask) !== (z & zMask)) {
                // Setting bit `step` of x flips bit `pos` of the product and
                // leaves the lower bits alone.
                acc = (acc + ((y & zMask) << step)) >>> 0;
                x |= 1 << step;
            }
            decided |= 1 << step;
            zMask |= 2 << pos;
        }
        return x >>> 0;
    }
    

    Unit test and benchmark

    The unit test is partially exhaustive and partially randomized. The entire test set is divided into 33 subsets. Each subset is defined by a parameter tz, ranging from 0 to 32, which gives the number of trailing zeros in y (z may have the same number of trailing zeros or more, but not fewer). I've set the default number of test samples per subset to 20K. If a subset has at most 20K possible combinations of y and z, it is tested exhaustively. Otherwise, the subset is populated by 20K randomly generated unique pairs of y and z. When 20K is more than half of all combinations, every combination is enumerated instead, except for a random selection of pairs that are skipped.

    Some unsolvable combinations are tested as well, but those are procedurally generated constants.

    I also added y >>>= 0; z >>>= 0; to each test function just to comply with how JS natively handles values that are expected to be Uint32.

    Testing on Firefox 113/Win10 shows that @harold's implementation is the fastest.

    // Candidate solvers keyed by author. Each one takes (y, z) and treats both
    // as Uint32. It returns a Uint32 x with Math.imul(x, y) >>> 0 === z, or
    // NaN when z has fewer trailing zeros than y (no solution exists).
    let solutions = {
    
        // Bit-by-bit lifting that checks each candidate with Math.imul on a
        // progressively wider mask.
        blakkwater1: (y, z) => {
            y >>>= 0;
            z >>>= 0;
            let width = 1;
            for (; width <= 32; width++) {
                const bit = 1 << (width - 1);
                if ((y & bit) !== 0) {
                    break;
                }
                if ((z & bit) !== 0) {
                    return NaN;
                }
            }
            let low = 0;
            let high = 1;
            for (let shift = 1; width < 32; width++, shift++) {
                const mask = 0xffffffff >>> (32 - width);
                if ((Math.imul(low, y & mask) & mask) !== (z & mask)) {
                    low = high;
                }
                high = low | (1 << shift);
            }
            return (Math.imul(low, y) >>> 0) === z ? low >>> 0 : high >>> 0;
        },
    
        // Multiplication-free lifting: only shifts, additions and comparisons.
        blakkwater2: (y, z) => {
            y >>>= 0;
            z >>>= 0;
            if (y === 0) {
                return z === 0 ? 0 : NaN;
            }
            let pos = Math.clz32(y & -y) ^ 31;
            if (pos !== 0 && (z << (32 - pos)) !== 0) {
                return NaN;
            }
            let x = 0;
            let decided = 0;
            let acc = 0;
            let zMask = 0xffffffff >>> (31 - pos);
            for (let step = 0; pos < 32; step++, pos++) {
                if ((y & (1 << pos)) !== 0) {
                    acc = (acc + ((x & decided) << pos)) >>> 0;
                }
                if ((acc & zMask) !== (z & zMask)) {
                    acc = (acc + ((y & zMask) << step)) >>> 0;
                    x |= 1 << step;
                }
                decided |= 1 << step;
                zMask |= 2 << pos;
            }
            return x >>> 0;
        },
    
        // Removes y's power-of-two factor, then multiplies by the inverse of
        // the remaining odd factor modulo 2^32.
        harold: (() => {
            // Inverse of an odd d modulo 2^32. Starts from an initial
            // approximation and refines it with three Newton steps
            // x <- x * (2 - x * d).
            const inverse = (d) => {
                let inv = Math.imul(d, d) + d - 1;
                for (let i = 0; i < 3; i++) {
                    inv = Math.imul(inv, 2 - Math.imul(inv, d));
                }
                return inv >>> 0;
            };
            // Trailing-zero count; evaluates to 63 for v === 0.
            const ctz = (v) => Math.clz32(v & -v) ^ 31;
            return (y, z) => {
                y >>>= 0;
                z >>>= 0;
                const yShift = ctz(y);
                if (yShift > ctz(z)) {
                    return NaN;
                }
                return Math.imul(z >>> yShift, inverse(y >>> yShift)) >>> 0;
            };
        })(),
    };
    
    /**
     * Unit-test and benchmark harness for solvers `func(y, z)` of
     * `Math.imul(x, y) >>> 0 === z`.
     *
     * Regular (solvable) samples are split into 33 sets indexed by `tz`, the
     * number of trailing zeros in `y` (tz = 32 stands for y = 0). In set `tz`,
     * every `z` has at least `tz` trailing zeros, so every pair has a solution.
     * A set is enumerated exhaustively when it has at most `samplesPerSet`
     * pairs. Otherwise `samplesPerSet` distinct pairs are used, chosen at
     * random. A fixed list of unsolvable pairs is checked separately.
     */
    class UnitTest {
    
        #sets;            // one sample-set descriptor per tz, indexed 0..32
        #regularSamples;  // total number of regular samples over all sets
    
        /**
         * @param {number} samplesPerSet - maximum number of samples per tz set
         *     (floored, then raised to at least 1).
         */
        constructor (samplesPerSet) {
    
            samplesPerSet = Math.max(Math.floor(samplesPerSet), 1);
            const sets = [];
            let regularSamples = 0;
    
            for (let tz = 0; tz <= 32; tz++) {
                const set = UnitTest.createSet(tz, samplesPerSet);
                sets.push(set);
                regularSamples += set.sampleCount;
            }
    
            this.#sets = sets;
            this.#regularSamples = regularSamples;
        }
    
        /**
         * Runs `func` over every regular sample (timed) and every unsolvable
         * sample (not timed).
         *
         * A regular sample passes only if `func` returns a Uint32 integer `x`
         * with `Math.imul(x, y) >>> 0 === z`. An unsolvable sample passes only
         * if `func` returns NaN. Exceptions thrown by `func` are counted as
         * "crashed" and are not propagated.
         *
         * @param {(y: number, z: number) => number} func - solver under test.
         * @returns {{regularPassed: number, regularFailed: number,
         *     regularCrashed: number, regularTime: number,
         *     unsolvablePassed: number, unsolvableFailed: number,
         *     unsolvableCrashed: number}} counters. `regularTime` is in
         *     milliseconds (via `performance.now()`) and also includes the
         *     cost of iterating/generating the samples.
         */
        test (func) {
    
            const sets = this.#sets;
            let regularPassed = 0;
            let regularFailed = 0;
            let regularCrashed = 0;
            const timeStart = performance.now();
    
            for (let tz = 0; tz <= 32; tz++) {
                for (const [y, z] of sets[tz].start()) {
                    try {
                        const x = func(y, z);
                        // Math.imul coerces its arguments with ToInt32. NaN
                        // would therefore act as 0 and "solve" every z === 0,
                        // and out-of-range values would silently wrap. Only a
                        // genuine Uint32 counts as an answer.
                        const isUint32 = Number.isInteger(x) && x >= 0 && x <= 0xffffffff;
                        if (isUint32 && (Math.imul(x, y) >>> 0) === z) {
                            regularPassed++;
                        }
                        else {
                            regularFailed++;
                        }
                    }
                    catch {
                        regularCrashed++;
                    }
                }
            }
    
            const time = performance.now() - timeStart;
            let unsolvablePassed = 0;
            let unsolvableFailed = 0;
            let unsolvableCrashed = 0;
    
            for (const [y, z] of UnitTest.unsolvableSet.start()) {
                try {
                    if (Number.isNaN(func(y, z))) {
                        unsolvablePassed++;
                    }
                    else {
                        unsolvableFailed++;
                    }
                }
                catch {
                    unsolvableCrashed++;
                }
            }
    
            return { regularPassed, regularFailed, regularCrashed, regularTime: time, unsolvablePassed, unsolvableFailed, unsolvableCrashed };
        }
    
        /** Total number of regular (solvable) samples across all 33 sets. */
        get regularSamples () { return this.#regularSamples; }
    
        /**
         * Builds the sample-set descriptor for one `tz`. The strategy depends
         * on how many pairs the set can hold at most:
         * - at most `sampleCount` pairs: enumerate them all;
         * - more than that, but `sampleCount` is over half of them: enumerate
         *   them all while skipping a random subset;
         * - otherwise: draw `sampleCount` distinct random pairs.
         */
        static createSet (tz, sampleCount) {
    
            const maxSize = UnitTest.getMaxSetSize(tz);
            if (maxSize <= sampleCount) {
                return UnitTest.createFullSet(tz);
            }
            if (sampleCount > maxSize / 2) {
                return UnitTest.createExclusiveSet(tz, sampleCount);
            }
            return UnitTest.createInclusiveSet(tz, sampleCount);
        }
    
        /** Set of `sampleCount` distinct random pairs, yielded in insertion order. */
        static createInclusiveSet (tz, sampleCount) {
    
            const samples = UnitTest.generateRandomizedSet(tz, sampleCount);
            return {
                sampleCount,
                start: function * () {
                    for (const s of samples) {
                        yield [Number(s >> 32n), Number(s & 0xffffffffn)];
                    }
                }
            };
        }
    
        /**
         * Set of all pairs for `tz`, in (y, z) order, minus a random selection
         * of `maxSize - sampleCount` pairs that are skipped.
         */
        static createExclusiveSet (tz, sampleCount) {
    
            const sampleCountSkip = UnitTest.getMaxSetSize(tz) - sampleCount;
            // Pairs to skip, packed as (y << 32 | z). They are sorted so they
            // match the enumeration order below, and end with a sentinel
            // (0xffffffff, 0xffffffff).
            // NOTE(review): that sentinel is only safe while it cannot be a
            // real pair, i.e. for tz >= 1.
            const samples = new BigUint64Array(sampleCountSkip + 1);
            let si = 0;
            for (const s of UnitTest.generateRandomizedSet(tz, sampleCountSkip)) {
                samples[si++] = s;
            }
            samples[si] = 2n ** 64n - 1n;
            samples.sort();
            return {
                sampleCount,
                start: function * () {
                    const step = (1 << tz) >>> 0;
                    let next = 1;
                    let yskip = Number(samples[0] >> 32n);
                    let zskip = Number(samples[0] & 0xffffffffn);
                    for (let y = step; y < 0x100000000; y += step * 2) {
                        for (let z = 0; z < 0x100000000; z += step) {
                            if (y != yskip || z != zskip) {
                                yield [y, z];
                            }
                            else {
                                yskip = Number(samples[next] >> 32n);
                                zskip = Number(samples[next] & 0xffffffffn);
                                next++;
                            }
                        }
                    }
                }
            };
        }
    
        /**
         * Set of every pair for `tz`: y runs over the odd multiples of 2^tz,
         * z over all multiples of 2^tz. For tz = 32 the only pair is (0, 0).
         */
        static createFullSet (tz) {
    
            return {
                sampleCount: UnitTest.getMaxSetSize(tz),
                start: function * () {
                    if (tz == 32) {
                        yield [0, 0];
                        return;
                    }
                    const step = (1 << tz) >>> 0;
                    for (let y = step; y < 0x100000000; y += step * 2) {
                        for (let z = 0; z < 0x100000000; z += step) {
                            yield [y, z];
                        }
                    }
                }
            };
        }
    
        /**
         * Draws `sampleCount` distinct random pairs for `tz`, packed as
         * BigInt (y << 32 | z). Assumes tz < 32 (shift counts wrap at 32).
         */
        static generateRandomizedSet (tz, sampleCount) {
    
            const samples = new Set();
            // Clearing the low tz bits and then setting bit tz of y gives y
            // exactly tz trailing zeros; z gets at least tz.
            const mask = 0xffffffff << tz;
            const ybit1 = 1 << tz;
            for (let si = 0; si < sampleCount; si++) {
                let s;
                do {
                    const y = (Math.random() * 0x100000000 & mask | ybit1) >>> 0;
                    const z = (Math.random() * 0x100000000 & mask) >>> 0;
                    s = (BigInt(y) << 32n) | BigInt(z);
                }
                while (samples.has(s));
                samples.add(s);
            }
            return samples;
        }
    
        /**
         * Number of solvable pairs for `tz`. There are 2^(31 - tz) choices of
         * y and 2^(32 - tz) choices of z; for tz = 32 only (0, 0) remains.
         */
        static getMaxSetSize (tz) {
    
            return (tz < 32)? 2 ** ((32 - tz) * 2 - 1): 1;
        }
    
        // Pairs (2^k, 2^j) with j < k, for k = 1..32 (2^32 wraps to y = 0).
        // z always has fewer trailing zeros than y, so none is solvable.
        static unsolvableSet = {
            sampleCount: (1 + 32) * 32 / 2,
            start: function * () {
                for (let y = 2; y <= 0x100000000; y *= 2) {
                    for (let z = 1; z < y; z *= 2) {
                        yield [y >>> 0, z];
                    }
                }
            }
        };
    }
    
    /**
     * Formats `value` with `precision` significant digits and an SI-style
     * suffix: no suffix below 1e3, then ' K', ' M', ' G'.
     *
     * The prefix is chosen from the *rounded* value. This keeps results like
     * 999999 from printing as "1000 K" instead of "1.000 M".
     *
     * @param {number} value - number to format.
     * @param {number} precision - significant digits, passed to toPrecision
     *     (throws RangeError outside 1..100 for finite values).
     * @returns {string} formatted number. Non-finite values are returned as
     *     String(value), without a suffix.
     */
    function toMetricNotation (value, precision) {
        if (!Number.isFinite(value)) {
            // NaN / ±Infinity have no magnitude to scale; avoid "NaN G".
            return String(value);
        }
        const rounded = Number(value.toPrecision(precision));
        let prefix = '';
        let scaled = rounded;
        if (rounded >= 1e9) {
            prefix = ' G';
            scaled = rounded / 1e9;
        }
        else if (rounded >= 1e6) {
            prefix = ' M';
            scaled = rounded / 1e6;
        }
        else if (rounded >= 1e3) {
            prefix = ' K';
            scaled = rounded / 1e3;
        }
        return scaled.toPrecision(precision) + prefix;
    }
    
    // Runs every solver in `solutions` against the same test sets and prints
    // a pass/fail/crash report plus throughput for the regular samples.
    let ut = new UnitTest(20000);
    for (const [name, func] of Object.entries(solutions)) {
        const result = ut.test(func);
        const callsPerSecond = ut.regularSamples * 1000 / result.regularTime;
        console.log(`Test function: ${name}`);
        const unsolvableReport = [
            '- Unsolvable set:',
            `  - Passed: ${result.unsolvablePassed}`,
            `  - Failed: ${result.unsolvableFailed}`,
            `  - Crashed: ${result.unsolvableCrashed}`,
            `  - Total: ${UnitTest.unsolvableSet.sampleCount}`,
        ];
        console.log(unsolvableReport.join('\n'));
        const regularReport = [
            '- Regular tests:',
            `  - Passed: ${result.regularPassed}`,
            `  - Failed: ${result.regularFailed}`,
            `  - Crashed: ${result.regularCrashed}`,
            `  - Total: ${ut.regularSamples}`,
            `  - Time: ${Math.floor(result.regularTime)} ms`,
            `  - Performance: ${toMetricNotation(callsPerSecond, 4)} calls per second`,
        ];
        console.log(regularReport.join('\n'));
    }