c++ bit-manipulation stockfish swar

Multiplication of two packed signed integers in one


The Stockfish chess engine needs to store, for its evaluation, both an endgame score and a middlegame score.

Instead of storing them separately, it packs them into one int. The middlegame score is stored in the lower 16 bits. The endgame score is stored in the upper 16 bits: as-is if the middlegame score is non-negative, or minus one if it is negative.
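A small worked example may make this layout concrete. This is only a sketch: `pack`, `upper_half` and `lower_half` are hypothetical helpers, with `pack` mirroring the arithmetic of the `make_score` function quoted further down; they are not Stockfish identifiers.

```cpp
#include <cstdint>

// Hypothetical helper: eg goes into the upper 16 bits, then mg is added.
// A negative mg therefore borrows 1 from the upper half.
// Unsigned arithmetic is used so that nothing overflows a signed int.
constexpr std::int32_t pack(int mg, int eg) {
    return static_cast<std::int32_t>((static_cast<std::uint32_t>(eg) << 16)
                                     + static_cast<std::uint32_t>(mg));
}

// Raw (unsigned) views of the two 16-bit halves.
constexpr std::uint16_t upper_half(std::int32_t s) {
    return static_cast<std::uint16_t>(static_cast<std::uint32_t>(s) >> 16);
}
constexpr std::uint16_t lower_half(std::int32_t s) {
    return static_cast<std::uint16_t>(static_cast<std::uint32_t>(s));
}

// mg >= 0: the upper half holds eg unchanged.
static_assert(upper_half(pack(3, 5)) == 5, "eg stored as-is");
static_assert(lower_half(pack(3, 5)) == 3, "mg stored as-is");

// mg < 0: the upper half holds eg - 1, the lower half holds the
// 16-bit two's complement pattern of mg (0xFFFD for -3).
static_assert(upper_half(pack(-3, 5)) == 4, "eg minus one");
static_assert(lower_half(pack(-3, 5)) == 0xFFFD, "two's complement of -3");
```

The values 5, 3 and -3 are arbitrary; any pair that fits in 16 bits shows the same pattern.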

This has the advantage that operations (addition, subtraction, negation and multiplication) can be done for both numbers in parallel.

Here is the code:

/// Score enum stores a middlegame and an endgame value in a single integer (enum).
/// The least significant 16 bits are used to store the middlegame value and the
/// upper 16 bits are used to store the endgame value. We have to take care to
/// avoid left-shifting a signed int to avoid undefined behavior.
enum Score : int { SCORE_ZERO };

constexpr Score make_score(int mg, int eg) {
  return Score((int)((unsigned int)eg << 16) + mg);
}

/// Extracting the signed lower and upper 16 bits is not so trivial because
/// according to the standard a simple cast to short is implementation defined
/// and so is a right shift of a signed integer.
inline Value eg_value(Score s) {
  union { uint16_t u; int16_t s; } eg = { uint16_t(unsigned(s + 0x8000) >> 16) };
  return Value(eg.s);
}

inline Value mg_value(Score s) {
  union { uint16_t u; int16_t s; } mg = { uint16_t(unsigned(s)) };
  return Value(mg.s);
}

#define ENABLE_BASE_OPERATORS_ON(T)                                \
constexpr T operator+(T d1, int d2) { return T(int(d1) + d2); }    \
constexpr T operator-(T d1, int d2) { return T(int(d1) - d2); }    \
constexpr T operator-(T d) { return T(-int(d)); }                  \
inline T& operator+=(T& d1, int d2) { return d1 = d1 + d2; }       \
inline T& operator-=(T& d1, int d2) { return d1 = d1 - d2; }

ENABLE_BASE_OPERATORS_ON(Score)

/// Only declared but not defined. We don't want to multiply two scores due to
/// a very high risk of overflow. So user should explicitly convert to integer.
Score operator*(Score, Score) = delete;

/// Division of a Score must be handled separately for each term
inline Score operator/(Score s, int i) {
  return make_score(mg_value(s) / i, eg_value(s) / i);
}

/// Multiplication of a Score by an integer. We check for overflow in debug mode.
inline Score operator*(Score s, int i) {

  Score result = Score(int(s) * i);

  assert(eg_value(result) == (i * eg_value(s)));
  assert(mg_value(result) == (i * mg_value(s)));
  assert((i == 0) || (result / i) == s);

  return result;
}

I understand how addition, subtraction and negation work, but I have trouble understanding multiplication. How does multiplying the packed integer multiply both the endgame and the middlegame scores correctly?


Solution

  • Here's a proof, assuming addition works, which you already understand.

    It involves some algebra, so here's some handy notation for that. All variables are signed 16-bit values, 2's complement representation, unless stated otherwise.

    S(eg,mg) = [eg - neg(mg), mg]

    Here eg is an endgame score, and mg is a midgame score and [,] shows the bits of values packed into a 32-bit 2's complement signed integer. Most significant bits are to the left. S(eg,mg) is the representation of those scores. Here neg(x) is 1 if x is negative and 0 otherwise, so that the requirement that the representation of eg in the upper bits has one subtracted if mg is negative is satisfied.

    This is what make_score produces: adding a negative mg to eg << 16 borrows one from the upper half. Any other layout would be a bug, since it would not be consistent with eg_value for negative mg.

    We want to show that for any 32-bit signed integer i, multiplying S(eg,mg) by i just models multiplying eg and mg by i, which can be precisely stated as

    (T) i*S(eg,mg) = S(i*eg, i*mg) where 16-bit multiplications are used in the arguments.
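Before the proof, claim (T) can be spot-checked on a few values. The sketch below is an assumption-laden illustration, not the Stockfish code: `pack`, `mg_of`, `eg_of`, `sign_extend16` and `scale` are hypothetical names, and the multiplication is done in unsigned arithmetic to sidestep signed-overflow concerns (it yields the same bits as the signed product when nothing overflows).

```cpp
#include <cstdint>

// Hypothetical stand-in for make_score: eg in the upper half, then add mg.
// The final conversion back to int32_t is modular on mainstream compilers
// (and guaranteed to be so since C++20).
constexpr std::int32_t pack(int mg, int eg) {
    return static_cast<std::int32_t>((static_cast<std::uint32_t>(eg) << 16)
                                     + static_cast<std::uint32_t>(mg));
}

// Interprets the low 16 bits of `bits` as a signed 16-bit value.
constexpr int sign_extend16(std::uint32_t bits) {
    return (bits & 0x8000u) ? static_cast<int>(bits & 0xFFFFu) - 0x10000
                            : static_cast<int>(bits & 0xFFFFu);
}

constexpr int mg_of(std::int32_t s) {
    return sign_extend16(static_cast<std::uint32_t>(s));
}

// Adding 0x8000 first undoes the borrow caused by a negative lower half.
constexpr int eg_of(std::int32_t s) {
    return sign_extend16((static_cast<std::uint32_t>(s) + 0x8000u) >> 16);
}

// A single 32-bit multiplication scales both halves at once.
constexpr std::int32_t scale(std::int32_t s, int i) {
    return static_cast<std::int32_t>(static_cast<std::uint32_t>(s)
                                     * static_cast<std::uint32_t>(i));
}

// (T) for a positive and a negative multiplier, with a negative mg.
static_assert(mg_of(scale(pack(-3, 5), 7)) == -21, "mg scaled by 7");
static_assert(eg_of(scale(pack(-3, 5), 7)) == 35, "eg scaled by 7");
static_assert(mg_of(scale(pack(-3, 5), -7)) == 21, "mg scaled by -7");
static_assert(eg_of(scale(pack(-3, 5), -7)) == -35, "eg scaled by -7");
```

The check only covers products that still fit in 16 bits per half; once either scaled score leaves that range, the halves interfere and (T) no longer holds, which is what the asserts in the question's operator* guard against.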

    Proof: If i=0 this is obvious. Suppose that i>0 and that addition works, i.e. adding representations adds their endgame scores and adds their midgame scores. i.e. suppose

    (A) S(eg1,mg1) + S(eg2,mg2) = S(eg1+eg2,mg1+mg2) for any arguments.

    Then repeated use of (A) establishes (T) as follows.

    i*S(eg,mg) 
            // sum of i copies of S(eg,mg)
        = S(eg,mg) + S(eg,mg) + S(eg,mg) + ... + S(eg,mg)
            // (A) on first two terms
        = S(eg+eg,mg+mg) + S(eg,mg) + ... + S(eg,mg)
        = S(2*eg,2*mg) + S(eg,mg) + ... + S(eg,mg)
            // (A) on first two remaining terms
        = S(2*eg+eg,2*mg+mg) + ... + S(eg,mg)
        = S(3*eg,3*mg) + ... + S(eg,mg)
            // ...
        = S(i*eg,i*mg)
    
    

    So now (T) is established for all but negative i values.
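The repeated-addition argument for positive i can be mirrored in code. This is a minimal sketch under stated assumptions: `S` builds the representation directly from the definition [eg - neg(mg), mg], and `add_copies` and `scaling_matches` are hypothetical helper names. It assumes two's complement integers and inputs small enough that no 16-bit half overflows.

```cpp
#include <cstdint>

// S(eg, mg) built from the definition: upper half eg - neg(mg),
// lower half the 16-bit two's complement pattern of mg.
constexpr std::int32_t S(int eg, int mg) {
    return (eg - (mg < 0 ? 1 : 0)) * 65536 + (mg & 0xFFFF);
}

// Adds i copies of s one at a time, as in the derivation (i >= 0).
inline std::int32_t add_copies(std::int32_t s, int i) {
    std::int32_t sum = 0;
    for (int k = 0; k < i; ++k) {
        sum += s;
    }
    return sum;
}

// True when i * S(eg, mg), the sum of i copies of S(eg, mg),
// and S(i*eg, i*mg) all agree.
inline bool scaling_matches(int eg, int mg, int i) {
    const std::int32_t s = S(eg, mg);
    return s * i == add_copies(s, i) && s * i == S(i * eg, i * mg);
}
```

For instance, `scaling_matches(5, -3, 7)` compares 7 * S(5,-3) against S(35,-21), and `scaling_matches(-40, 12, 100)` compares 100 * S(-40,12) against S(-4000,1200).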

    Suppose i < 0, and let i = -j. Then

    i*S(eg,mg) 
        = (-1)*j*S(eg,mg)
            // j is positive
        = (-1)*S(j*eg,j*mg)
           // if (T) works for (-1)
        = S((-1)*j*eg,(-1)*j*mg)
        = S((-j)*eg,(-j)*mg)
        = S(i*eg,i*mg)
    

    so (T) is true for negative i provided it is true for i = -1, which is proved below.

    This is the crux of the matter: working for scaling by (-1), i.e. proving that -S(eg,mg) = S(-eg,-mg). The proof below completes the proof of (T), solving the problem. Here's the argument:

    -S(eg,mg) = -[eg - neg(mg), mg]
    

    and we want to show this equal to

    S(-eg,-mg) = [(-eg) - neg(-mg), (-mg)]
    

    and here's some algebra that does just that. It makes liberal use of the well known 2's complement identity that -a = ~a + 1, i.e. the arithmetic negation is one more than the bitwise negation.
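The identity itself is easy to confirm on a two's complement machine. The helper name `negation_identity_holds` below is hypothetical; the 16-bit argument is promoted to int before the operators apply, so nothing overflows even for the most negative 16-bit value.

```cpp
#include <cstdint>

// Checks -a == ~a + 1. The operand is promoted to int first,
// so the comparison is carried out without overflow.
constexpr bool negation_identity_holds(std::int16_t a) {
    return -a == ~a + 1;
}

static_assert(negation_identity_holds(0), "a = 0");
static_assert(negation_identity_holds(1), "a = 1");
static_assert(negation_identity_holds(-3), "a = -3");
static_assert(negation_identity_holds(32767), "largest 16-bit value");
static_assert(negation_identity_holds(-32768), "smallest 16-bit value");
```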

    -[eg - neg(mg), mg]
                // `-a = ~a + 1`
            = ~[eg - neg(mg), mg] + 1
                // ~ is bitwise
            = [~(eg - neg(mg)), ~mg] + 1
                // if mg=0 the carry on adding 1 will propagate to the upper bits
    case1: mg=0
    [~(eg - neg(mg)), ~mg] + 1
               // neg(mg) = 0
            = [~eg, ~mg] + 1
                // ~mg is a bit pattern of all ones
            = [~eg + 1, ~mg + 1]
                // `~a = -a - 1`
            = [-eg, -mg]
                // neg(-mg)=0 because mg=0
            = [(-eg) - neg(-mg), (-mg)] 
            = S(-eg,-mg) as desired.
    case2: mg≠0
    [~(eg - neg(mg)), ~mg] + 1
                // carry does not propagate to upper bits
            = [~(eg - neg(mg)), ~mg + 1]
                // `~a = -a - 1`
            = [-(eg - neg(mg)) - 1, -mg]
            = [-eg - (1 - neg(mg)), -mg]
                // neg(-mg) = 1 - neg(mg) for mg≠0
            = [(-eg) - neg(-mg), (-mg)]
            = S(-eg,-mg) as desired.
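Both cases of the negation argument can be checked numerically. As before this is a sketch, not the engine's code: `S` follows the definition [eg - neg(mg), mg], `negation_matches` is a hypothetical helper, and two's complement integers are assumed.

```cpp
#include <cstdint>

// S(eg, mg) built from the definition: upper half eg - neg(mg),
// lower half the 16-bit two's complement pattern of mg.
constexpr std::int32_t S(int eg, int mg) {
    return (eg - (mg < 0 ? 1 : 0)) * 65536 + (mg & 0xFFFF);
}

// True when negating the packed value negates both scores.
constexpr bool negation_matches(int eg, int mg) {
    return -S(eg, mg) == S(-eg, -mg);
}

// case 1: mg = 0, the carry from adding 1 reaches the upper half.
static_assert(negation_matches(5, 0), "mg = 0");
static_assert(negation_matches(-5, 0), "mg = 0, negative eg");

// case 2: mg != 0, the carry stays in the lower half.
static_assert(negation_matches(5, 3), "positive mg");
static_assert(negation_matches(5, -3), "negative mg");
static_assert(negation_matches(-40, 12), "negative eg");
```

Concretely, S(5,3) is 327683 and S(-5,-3) is -327683, so the packed negation and the per-score negation agree.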