Search code examples
set nim-lang

pop() for built-in sets


The HashSet in std/sets has a pop() function that removes and returns an arbitrary element from the set.

Is there an equivalent for built-in sets in Nim? If not, is there some other recommended way to get an arbitrary element from a set? It doesn't matter if it's destructive (like pop()) or not.


Solution

  • Ok, so this isn't the 'recommended' way to do this, obviously, but it does work, and it can even be worth the trouble if micro-optimising.

    Nim sets, under the hood, are either the smallest uintx that fits (up to T.high - T.low = 63), or an array of uint8.

    So the basic idea is:

    import std/bitops
    proc lowestElement[T](s:set[T]):T=
      ## Returns the smallest member of the built-in set `s`, found by looking
      ## for the lowest set bit in the set's raw in-memory representation.
      ##
      ## Raises `KeyError` when `s` is empty.
      if s.card == 0:
        raise newException(KeyError,"empty set")

      when sizeof(s) > 8:
        # Wide sets are viewed as raw bytes: skip the all-zero bytes, then
        # locate the lowest set bit inside the first non-zero one.
        let rawBytes = cast[array[sizeof(s), uint8]](s)
        var byteIdx = 0
        while rawBytes[byteIdx] == 0:
          inc byteIdx
        # firstSetBit is 1-based, hence the trailing `- 1`.
        T(firstSetBit(rawBytes[byteIdx]) + byteIdx*8 + T.low.int - 1)
      else:
        # Narrow sets are viewed as the unsigned integer of the same size.
        type Word = (
          when sizeof(s) == 1: uint8
          elif sizeof(s) == 2: uint16
          elif sizeof(s) == 4: uint32
          elif sizeof(s) == 8: uint64)
        let word = cast[Word](s)
        T(firstSetBit(word) + T.low.int - 1)
    

    I've used firstSetBit, but countLeadingZeroBits could compile to fewer instructions; the catch is that it gives you the largest member of the set rather than the smallest. You'd need to reverse the iteration over the array as well:

      # Drop-in alternative for the final `when` of lowestElement: yields the
      # largest member by counting leading zero bits. The caller must already
      # have ruled out the empty set.
      when sizeof(s)<=8:
        # Index of the highest set bit, offset by the lowest value of T.
        T(sizeof(s)*8 - countLeadingZeroBits(theBits) + T.low.int - 1)
      else:
        # Walk the bytes from the last one downwards to the first non-zero byte.
        var i = sizeof(s) - 1
        while theBits[i]==0: dec i
        T(8 - countLeadingZeroBits(theBits[i]) + i*8 + T.low.int - 1)
    

    You absolutely need to handle the empty-set case somehow, and not just because of that infinite loop: those bit ops are undefined on 0. (OK, firstSetBit is defined on GCC, but only because it checks for zero for you.)

    This should work with any ordinal type, including holey enums


    OK, below here I'm getting a bit silly, but I had to ask myself: is this actually faster, and can I vectorize it?

    and i'm afraid i did this (thanks to this question):

    import std/[bitops,options]
    import nimsimd
    
    template highestSingleImpl[T](s:set[T]): untyped =
      ## Expression template for sets that occupy at most 8 bytes: evaluates to
      ## `some(x)` where `x` is the largest member of `s`, or to `none(T)` when
      ## `s` is empty.
      ##
      ## The return type is declared so the expansion can be used as a value
      ## (for example as the result expression of a proc).
      # Unsigned integer with the same size as the set.
      type SetT = (
        when sizeof(s)== 1: uint8
        elif sizeof(s)==2: uint16
        elif sizeof(s)==4: uint32
        elif sizeof(s)==8: uint64
      )
      let theBits = cast[SetT](s)
      if theBits == 0:
        # Guard first: counting leading zeros of 0 is not meaningful.
        none(T)
      else:
        # Index of the highest set bit, offset by the lowest value of T.
        T(sizeof(s)*8 - countLeadingZeroBits(theBits) + T.low.int - 1).some
    
    template highestMultipleImpl[T](s:set[T]) =
      ## Statement template: scans the bytes of `s` from the end towards the
      ## start and `return`s from the enclosing routine with `some` of the
      ## element that corresponds to the highest set bit, or with `none` when
      ## no byte is non-zero.
      ##
      ## The bytes are split into a leading remainder of `sizeof(s) mod 32`
      ## bytes (if any) followed by whole 32-byte chunks.
      type
        HighT = ptr UncheckedArray[array[32,uint8]]
      when sizeof(s) mod 32 != 0:
        type
          LowT = ptr UncheckedArray[array[sizeof(s) mod 32,uint8]]
        let
          # The 32-byte chunks start right after the remainder bytes.
          hiBits = cast[HighT](cast[LowT](s.unsafeAddr)[1].addr)
          # The remainder bytes at the start of the set.
          loBits = cast[LowT](s.unsafeAddr)[0]
      else:
       let hiBits = cast[HighT](s.unsafeAddr)
    
      # Visit the 32-byte chunks from last to first.
      var i = sizeof(s) div 32 - 1
      while i >= 0:
        var
          vec = mm256_loadu_si256(hiBits[i].addr)
          # Compares every byte with zero; despite the name this marks the
          # bytes that ARE zero (assuming the usual AVX2 semantics).
          nonzero_elem = mm256_cmpeq_epi8(vec, mm256_setzero_si256())
          # Inverted byte mask: one bit per byte that is non-zero.
          mask = not mm256_movemask_epi8(nonzero_elem)
        if mask == 0:
          # Chunk holds no members; move on to the previous one.
          dec i
          continue
        let
          # Position of the highest non-zero byte within this chunk.
          idx = 31 - countLeadingZeroBits(mask)
          highest_nonzero_byte = hiBits[i][idx]
        # Bit offset = chunk offset + byte offset + bit within the byte, plus
        # the bits of the remainder bytes that precede the chunks.
        return T(i*32*8 + idx*8 + 8 - countLeadingZeroBits(highest_nonzero_byte) + T.low.int - 1 + 8 * (sizeof(s) mod 32)).some
    
      when sizeof(s) mod 32 != 0:
        # No member in any chunk: check the remainder bytes one at a time,
        # again from last to first.
        i = (sizeof(s) mod 32) - 1
        while i >= 0:
          if loBits[i]==0:
            dec i
            continue
          return T(8 - countLeadingZeroBits(loBits[i]) + i*8 + T.low.int - 1).some
      # NOTE(review): the sibling routines spell this `none(T)`; presumably both
      # forms produce the same `Option[T]` -- confirm.
      return none(typedesc[T])
    
    
    
    proc highestElement[T](s:set[T]):Option[T]{.raises:[].}=
      ## Returns the largest member of `s` wrapped in `some`, or `none` when the
      ## set is empty. Never raises.
      ##
      ## The implementation is picked at compile time from the size of the set:
      ## sets wider than 8 bytes take the chunked scan, the rest are handled as
      ## a single machine word.
      when sizeof(s) > 8:
        highestMultipleImpl(s)
      else:
        highestSingleImpl(s)
    
    

    i compared that to

    proc naiveHighest[T](s:set[T]):Option[T] =
      ## Baseline for comparison: returns the largest member of `s` wrapped in
      ## `some`, or `none(T)` when the set is empty.
      ##
      ## Tests every candidate value from `T.high` down to `T.low` for
      ## membership, so the cost grows with the size of `T`'s value range.
      for i in countdown(T.high.int, T.low.int):
        if T(i) in s:
          return T(i).some
      return none(T)
    

    with a few different versions of

    block:
      # Benchmark case for a set whose byte size is not a multiple of 32, so
      # the scan has to deal with leftover bytes besides the 32-byte chunks.
      type X = range[8..int16.high]
      doAssert sizeof(set[X]) mod 32 == 31
      var x:set[X]
      # Seed the set with some random members.
      for i in 0..100:
        x.incl(rand(X))

      # NOTE(review): `timeIt` and `keep` are presumably the benchy helpers --
      # confirm against the imports of the full benchmark file.
      timeIt "not evenly vectorizable":
        var h:X
        for _ in 0..1000:
          # Add a member first so the set is never empty when it is queried,
          # then remove whatever was reported as the highest member.
          x.incl(rand(X))
          h = x.highestElement.unsafeGet
          x.excl(h)
        keep(h)
    

    and i got:

       min time    avg time  std dv   runs name
       0.061 ms    0.069 ms  ±0.004  x1000 mod 32 == 0, vectorized
      18.571 ms   22.570 ms  ±1.302   x219 mod 32 == 0, naive
       0.068 ms    0.072 ms  ±0.002  x1000 mod 32 == 31, vectorized
      18.892 ms   22.742 ms  ±0.485   x221 mod 32 == 31, naive
       0.791 ms    0.836 ms  ±0.016  x1000 mod 32 == 31, clz only
    

    so, there you go.