
Implementing the x86 PDEP/PEXT instructions efficiently in SMTlib


Is there a way to specify the PDEP/PEXT instructions efficiently in SMTlib bitvector syntax?

My best attempt for PEXT ends up with something to the tune of: "Iff bit N in mask is set, then bit count_bits(mask[0..N]) in the result is equal to bit N in the input value.". But this requires a way to count bits, which isn't available in QF_BV.

Similarly for PDEP I end up with something like: "Iff bit N in mask is set, then bit N in the result is equal to bit count_bits(mask[0..N]) in the input value." Which again, requires counting bits.

I could write a naive bit counting function, but that would probably require evaluating all bits in the bitvector one-by-one. The entire specification would end up being O(N^2), which shouldn't be necessary.

Intel gives this implementation for PDEP:

TEMP := SRC1;
MASK := SRC2;
DEST := 0 ;
m := 0, k := 0;
DO WHILE m < OperandSize
    IF MASK[ m] = 1 THEN
        DEST[ m] := TEMP[ k];
        k := k+ 1;
    FI
    m := m+ 1;
OD

...but I am having trouble translating the while loop to SMTlib syntax.


Solution

  • When programming symbolically, you can't really avoid this sort of complexity. Your implementation must work correctly for all symbolic values of the input and the mask. Without concrete values for these, it's pointless to optimize: the solver has to consider all possible paths, so optimization tricks don't really help. (You can think of symbolic simulation as "running for all possible inputs," so no shortcuts apply.)

    So, instead of focusing on the complexity of the loop, simply write the necessary constraints on the final value. Unfortunately, SMTLib isn't a terribly good choice for this sort of programming, as it lacks many of the usual programming idioms. But that's precisely why we have programmable APIs built on top of solvers. Z3 can be scripted from many languages, including C, C++, Python, Haskell, and Java, so I'd recommend picking the language you're most familiar with and learning the bindings there.

    Based on this, I'll show how to code PEXT using the popular Python API, and also in Haskell. At the end, I'll also show how to do this in SMTLib by generating the code, which might come in handy. Note that I haven't extensively tested these, so while I trust the basic implementation is correct, you should double-check before using it for anything serious.

    Python

    from z3 import *
    
    # Implements bv[idx] = b, where idx is concrete. Assumes 1 <= size
    def updateBit(size, bv, idx, b):
        if size == 1:
            # A 1-bit vector has no neighbouring bits to keep; Extract would be ill-formed
            return b
        elif idx == 0:
            return Concat(Extract(size-1, idx+1, bv), b)
        elif idx == size-1:
            return Concat(                            b, Extract(idx-1, 0, bv))
        else:
            return Concat(Extract(size-1, idx+1, bv), b, Extract(idx-1, 0, bv))
    
    # Implements: bv[idx] = b, where idx can be symbolic. Assumes 1 <= size <= 2^8
    def Update(size, bv, idx, b):
        for i in range(size):
            bv = If(BitVecVal(i, 8) == idx, updateBit(size, bv, i, b), bv)
        return bv
    
    # Implements PEXT, by returning the resulting bit-vector. Assumes 1 <= size <= 2^8
    def PEXT(size, src, mask):
        dest = BitVecVal(0, size)  # DEST := 0, so bits that are never written stay zero
        idx  = BitVecVal(0, 8)
    
        for m in [Extract(i, i, mask) for i in range(size)]:
            srcBit = Extract(0, 0, src)
            src    = LShR(src, 1)
            dest   = If(m == BitVecVal(1, 1), Update(size, dest, idx, srcBit), dest)
            idx    = If(m == BitVecVal(1, 1), idx+1, idx)
    
        return dest
    
    # Simple example
    x, m, r = BitVecs('x m r', 8)
    s = Solver()
    s.add(x == BitVecVal(0xAA, 8))
    s.add(m == BitVecVal(0xAA, 8))
    s.add(r == PEXT(8, x, m))
    print(s.check())
    
    mdl = s.model()
    def grab(v):
        return hex(mdl[v].as_long())
    
    print("PEXT(8, " + grab(x) + ", " + grab(m) + ") = " + grab(r))
    

    When I run this, I get:

    sat
    PEXT(8, 0xaa, 0xaa) = 0xf
    

    Of course, this is just one example, so please convince yourself by studying the implementation in detail.
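
    One cheap way to build that confidence is to compare the symbolic version against a plain, non-symbolic reference. The sketch below transcribes Intel's loop onto ordinary Python integers. The names pext_ref and pdep_ref are my own (they are not part of z3 or any other library), and the functions only handle concrete values:

```python
def pext_ref(src: int, mask: int, size: int) -> int:
    """Concrete PEXT: gather the bits of src selected by mask into the low bits."""
    dest, k = 0, 0
    for m in range(size):
        if (mask >> m) & 1:
            dest |= ((src >> m) & 1) << k   # DEST[k] := SRC[m]
            k += 1
    return dest


def pdep_ref(src: int, mask: int, size: int) -> int:
    """Concrete PDEP: scatter the low bits of src to the positions set in mask."""
    dest, k = 0, 0
    for m in range(size):
        if (mask >> m) & 1:
            dest |= ((src >> k) & 1) << m   # DEST[m] := SRC[k]
            k += 1
    return dest


if __name__ == "__main__":
    print(hex(pext_ref(0xAA, 0xAA, 8)))
    print(hex(pdep_ref(0x0F, 0xAA, 8)))
```

    Feeding the same concrete inputs to both the reference and the solver model, for a handful of random values, is a quick sanity check, though of course not a proof.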

    Haskell

    For Haskell, I'll use the SBV library (https://hackage.haskell.org/package/sbv). If you're familiar with Haskell, it's much more expressive than the Python API and a lot easier to use:

    {-# LANGUAGE DataKinds           #-}
    {-# LANGUAGE ScopedTypeVariables #-}
    {-# LANGUAGE TypeApplications    #-}
    
    import Data.SBV
    import GHC.TypeLits (KnownNat)
    
    pext :: forall n. (KnownNat n, BVIsNonZero n) => SWord n -> SWord n -> SWord n
    pext src mask = walk 0 src 0 (blastLE mask)
      where walk dest _ _   []     = dest
            walk dest x idx (m:ms) = walk (ite m (sSetBitTo dest idx (lsb x)) dest)
                                          (x `shiftR` 1)
                                          (ite m (idx + 1) idx)
                                          ms
    

    Which can be tested by:

    test :: IO SatResult
    test = satWith z3{printBase = 16} $
                 do x :: SWord 8 <- sWord "x"
                    m :: SWord 8 <- sWord "m"
                    r :: SWord 8 <- sWord "r"
    
                    constrain $ x .== 0xAA
                    constrain $ m .== 0xAA
                    constrain $ r .== pext @8 x m
    

    which prints:

    *Main> test
    Satisfiable. Model:
      x = 0xaa :: Word8
      m = 0xaa :: Word8
      r = 0x0f :: Word8
    

    And here's pdep for completeness:

    pdep :: forall n. (KnownNat n, BVIsNonZero n) => SWord n -> SWord n -> SWord n
    pdep src mask = walk 0 src 0 (blastLE mask)
      where walk dest _ _   []     = dest
            walk dest x idx (m:ms) = walk (ite m (sSetBitTo dest idx (lsb x)) dest)
                                          (ite m (x `shiftR` 1) x)
                                          (idx + 1)
                                          ms
    

    Hopefully you can adapt these ideas and implement them in your language of choice in a similar fashion. Again, please test further before using this for anything serious; I only tested these on the one value shown above.

    SMTLib

    As I mentioned, SMTLib isn't really suitable for this sort of programming, since it lacks the usual programming language idioms. But if you have to have this coded in SMTLib, then you can use the SBV library in Haskell to generate the function for you. Note that SMTLib is a "monomorphic" language; that is, you can't have a definition of PEXT that works for arbitrary bit-sizes. You'll have to generate one for each size you need. The generated code will not be pretty to look at, as it will be a full unfolding, but I trust it should perform well. Here's the output for the 2-bit case. (As you can imagine, larger bit sizes will yield bigger functions.)

    *Main> sbv2smt (smtFunction "pext_2" (pext @2)) >>= putStrLn
    ; Automatically generated by SBV. Do not modify!
    ; pext_2 :: SWord 2 -> SWord 2 -> SWord 2
    (define-fun pext_2 ((l1_s0 (_ BitVec 2)) (l1_s1 (_ BitVec 2))) (_ BitVec 2)
      (let ((l1_s3 #b0))
      (let ((l1_s7 #b01))
      (let ((l1_s8 #b00))
      (let ((l1_s20 #b10))
      (let ((l1_s2 ((_ extract 1 1) l1_s1)))
      (let ((l1_s4 (distinct l1_s2 l1_s3)))
      (let ((l1_s5 ((_ extract 0 0) l1_s1)))
      (let ((l1_s6 (distinct l1_s3 l1_s5)))
      (let ((l1_s9 (ite l1_s6 l1_s7 l1_s8)))
      (let ((l1_s10 (= l1_s7 l1_s9)))
      (let ((l1_s11 (bvlshr l1_s0 l1_s7)))
      (let ((l1_s12 ((_ extract 0 0) l1_s11)))
      (let ((l1_s13 (distinct l1_s3 l1_s12)))
      (let ((l1_s14 (= l1_s8 l1_s9)))
      (let ((l1_s15 ((_ extract 0 0) l1_s0)))
      (let ((l1_s16 (distinct l1_s3 l1_s15)))
      (let ((l1_s17 (ite l1_s16 l1_s7 l1_s8)))
      (let ((l1_s18 (ite l1_s6 l1_s17 l1_s8)))
      (let ((l1_s19 (bvor l1_s7 l1_s18)))
      (let ((l1_s21 (bvand l1_s18 l1_s20)))
      (let ((l1_s22 (ite l1_s13 l1_s19 l1_s21)))
      (let ((l1_s23 (ite l1_s14 l1_s22 l1_s18)))
      (let ((l1_s24 (bvor l1_s20 l1_s23)))
      (let ((l1_s25 (bvand l1_s7 l1_s23)))
      (let ((l1_s26 (ite l1_s13 l1_s24 l1_s25)))
      (let ((l1_s27 (ite l1_s10 l1_s26 l1_s23)))
      (let ((l1_s28 (ite l1_s4 l1_s27 l1_s18)))
      l1_s28))))))))))))))))))))))))))))
    

    You can drop the above (or the version for whatever size you generate) into your SMTLib program and use it as is, if you really must stick to SMTLib.
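
    If Haskell isn't an option, the same "one definition per bit size" idea can be sketched with a few lines of ordinary Python that print the SMTLib text directly. This is not what SBV produces. It is a hand-rolled alternative under my own assumptions: the function name gen_pext_smt and the let-bound names d0, d1, ... (destination so far) and k0, k1, ... (next output position) are made up for illustration, and I have not run the generated text through a solver, so treat it as a starting point. It relies on each SMTLib let binding its variables in parallel, so step i+1 sees only the values from step i:

```python
def gen_pext_smt(size: int) -> str:
    """Emit an SMT-LIB define-fun for PEXT at one fixed bit width (size >= 1)."""
    bv = f"(_ BitVec {size})"
    lines = [
        f"(define-fun pext_{size} ((src {bv}) (mask {bv})) {bv}",
        f"  (let ((d0 (_ bv0 {size})) (k0 (_ bv0 {size})))",
    ]
    for i in range(size):
        is_set = f"(= ((_ extract {i} {i}) mask) #b1)"
        # Widen bit i of src to the full width, then shift it to position k.
        bit = f"((_ zero_extend {size - 1}) ((_ extract {i} {i}) src))"
        d_next = f"(ite {is_set} (bvor d{i} (bvshl {bit} k{i})) d{i})"
        k_next = f"(ite {is_set} (bvadd k{i} (_ bv1 {size})) k{i})"
        lines.append(f"  (let ((d{i + 1} {d_next})")
        lines.append(f"        (k{i + 1} {k_next}))")
    # Close every let (size + 1 of them) plus the define-fun itself.
    lines.append(f"  d{size}" + ")" * (size + 2))
    return "\n".join(lines)


if __name__ == "__main__":
    print(gen_pext_smt(2))
```

    The output grows linearly with the bit size, one let per mask bit. After pasting the result for, say, size 8 into a script, it would be used as (assert (= r (pext_8 x m))).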