The `HashSet` in `std/sets` has a `pop()` function that removes and returns an arbitrary element from the set. Is there an equivalent for built-in sets in Nim? If not, is there some other recommended way to get an arbitrary element from a set? It doesn't matter if it's destructive (like `pop()`) or not.
OK, so this obviously isn't the 'recommended' way to do this, but it does work, and it can even be worth the trouble if you're micro-optimising.
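For comparison, here is the portable route with no casting at all. Iterating a built-in set yields its elements in ascending order, so the first element yielded is an arbitrary (in fact, the smallest) member. This is a minimal sketch; `anyElement` and `popAny` are hypothetical names, not part of the standard library:

```nim
proc anyElement[T](s: set[T]): T =
  ## Returns one element of `s`: set iteration is ascending, so this is
  ## the smallest. Raises KeyError if `s` is empty.
  for x in s:
    return x
  raise newException(KeyError, "empty set")

proc popAny[T](s: var set[T]): T =
  ## Destructive variant in the spirit of HashSet.pop:
  ## removes and returns one element (here, the smallest).
  result = anyElement(s)
  s.excl(result)

var digits = {'7', '3', '9'}
doAssert popAny(digits) == '3'
doAssert digits == {'7', '9'}
```

In the worst case this walks over every possible value in the set's range before finding a member, which is the cost the bit tricks are trying to avoid.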
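Under the hood, Nim sets are either the smallest unsigned integer that fits (`uint8`, `uint16`, `uint32` or `uint64`, which covers up to `T.high - T.low == 63`, i.e. 64 elements) or an array of `uint8`. Either way, bit `i` stands for the element `T.low + i`.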
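To check these layout claims on your own compiler, here is a quick sketch. The type names `Small` and `Offset` are made up for the example, and the asserted sizes reflect how the Nim compiler currently lays out sets, which is an implementation detail rather than a language guarantee:

```nim
type
  Small = range[0..7]    # 8 possible values, starting at 0
  Offset = range[8..15]  # 8 possible values, starting at 8

# storage size: smallest unsigned int that fits, else a byte array
doAssert sizeof(set[Small]) == 1
doAssert sizeof(set[range[0..23]]) == 4   # 24 bits round up to 32
doAssert sizeof(set[range[0..63]]) == 8
doAssert sizeof(set[range[0..64]]) == 9   # 65 bits: 9 bytes
doAssert sizeof(set[char]) == 32

# bit i holds the element T.low + i
let smallSet = {Small(1), Small(3)}
doAssert cast[uint8](smallSet) == 0b0000_1010'u8
let offsetSet = {Offset(8)}
doAssert cast[uint8](offsetSet) == 1'u8

# byte-array layout: element e sits in byte (e - T.low) div 8,
# bit (e - T.low) mod 8
let bigSet = {'A'}   # ord('A') == 65 and char.low == '\0'
doAssert cast[array[32, uint8]](bigSet)[8] == 0b10'u8
```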
So the basic idea is:
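```nim
import std/bitops

proc lowestElement[T](s: set[T]): T =
  # the bit ops below are undefined on 0, so reject the empty set up front
  if s.card == 0: raise newException(KeyError, "empty set")
  # reinterpret the set as its underlying storage type
  type SetT = (
    when sizeof(s) == 1: uint8
    elif sizeof(s) == 2: uint16
    elif sizeof(s) == 4: uint32
    elif sizeof(s) == 8: uint64
    else: array[sizeof(s), uint8])
  let theBits = cast[SetT](s)
  when sizeof(s) <= 8:
    # firstSetBit is 1-based, hence the - 1
    T(firstSetBit(theBits) + T.low.int - 1)
  else:
    # skip empty bytes; byte i holds elements T.low + i*8 .. T.low + i*8 + 7
    var i = 0
    while theBits[i] == 0: inc i
    T(firstSetBit(theBits[i]) + i*8 + T.low.int - 1)
```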
I've used `firstSetBit`, but `countLeadingZeroBits` could compile to fewer instructions, although it gives you the largest member of the set instead. You'd need to reverse the iteration over the array as well:
```nim
# replaces the `when` block at the end of lowestElement
when sizeof(s) <= 8:
  T(sizeof(s)*8 - countLeadingZeroBits(theBits) + T.low.int - 1)
else:
  var i = sizeof(s) - 1
  while theBits[i] == 0: dec i
  T(8 - countLeadingZeroBits(theBits[i]) + i*8 + T.low.int - 1)
```
You absolutely need to handle the empty-set case somehow, and not just because of that infinite loop: those bit ops are undefined on 0. (OK, `firstSetBit` is defined on GCC, but only because it checks for zero for you.)

This should work with any ordinal type, including holey enums.
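As a sanity check of the single-integer path on a holey enum, and of returning an `Option` instead of raising on the empty set, here is a cut-down sketch that only handles sets stored in one `uint8`. `Holey`, `lowestSmall` and `highestSmall` are hypothetical names for illustration:

```nim
import std/[bitops, options]

type Holey = enum
  a = 2, b = 5, c = 9   # ordinals span 2..9: 8 bits, so set[Holey] is one uint8

proc lowestSmall[T](s: set[T]): Option[T] =
  ## Hypothetical helper: only handles sets stored in a single uint8.
  when sizeof(s) != 1: {.error: "only single-byte sets are handled here".}
  let bits = cast[uint8](s)
  if bits == 0:
    return none(T)   # std/bitops does not guarantee firstSetBit(0)
  # converting an int to a holey enum may emit a HoleEnumConv warning
  some(T(firstSetBit(bits) - 1 + ord(T.low)))

proc highestSmall[T](s: set[T]): Option[T] =
  ## Same restriction; uses the leading-zero count instead.
  when sizeof(s) != 1: {.error: "only single-byte sets are handled here".}
  let bits = cast[uint8](s)
  if bits == 0:
    return none(T)   # countLeadingZeroBits(0) is undefined
  some(T(8 - countLeadingZeroBits(bits) - 1 + ord(T.low)))

doAssert lowestSmall({b, c}) == some(b)
doAssert highestSmall({a, b}) == some(b)
```

Returning `Option` pushes the empty-set decision to the caller, which is the same choice the vectorized `highestElement` makes.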
OK, below here I'm getting a bit silly, but I had to ask myself: is this actually faster, and can I vectorize it? I'm afraid I did this (thanks to this question):
```nim
import std/[bitops, options]
import nimsimd/avx2  # needs an AVX2-capable CPU; with GCC/Clang you may also need -mavx2

template highestSingleImpl[T](s: set[T]): untyped =
  # sets of up to 64 elements: a single unsigned integer
  type SetT = (
    when sizeof(s) == 1: uint8
    elif sizeof(s) == 2: uint16
    elif sizeof(s) == 4: uint32
    elif sizeof(s) == 8: uint64
  )
  let theBits = cast[SetT](s)
  if theBits == 0:
    none(T)
  else:
    T(sizeof(s)*8 - countLeadingZeroBits(theBits) + T.low.int - 1).some

template highestMultipleImpl[T](s: set[T]): untyped =
  # larger sets are bytes laid out as [r leading bytes][32-byte chunks...],
  # where r = sizeof(s) mod 32
  type HighT = ptr UncheckedArray[array[32, uint8]]
  when sizeof(s) mod 32 != 0:
    type LowT = ptr UncheckedArray[array[sizeof(s) mod 32, uint8]]
    let
      hiBits = cast[HighT](cast[LowT](s.unsafeAddr)[1].addr)
      loBits = cast[LowT](s.unsafeAddr)[0]
  else:
    let hiBits = cast[HighT](s.unsafeAddr)
  # scan the 32-byte chunks from the top down
  var i = sizeof(s) div 32 - 1
  while i >= 0:
    let
      vec = mm256_loadu_si256(hiBits[i].addr)
      zeroBytes = mm256_cmpeq_epi8(vec, mm256_setzero_si256())
      # bit k of mask is set iff byte k of the chunk is nonzero
      mask = not mm256_movemask_epi8(zeroBytes)
    if mask == 0:
      dec i
      continue
    let
      idx = 31 - countLeadingZeroBits(mask)
      highestNonzeroByte = hiBits[i][idx]
    return T(i*32*8 + idx*8 + 8 - countLeadingZeroBits(highestNonzeroByte) +
             T.low.int - 1 + 8 * (sizeof(s) mod 32)).some
  # then the r leading bytes, one at a time
  when sizeof(s) mod 32 != 0:
    i = (sizeof(s) mod 32) - 1
    while i >= 0:
      if loBits[i] == 0:
        dec i
        continue
      return T(8 - countLeadingZeroBits(loBits[i]) + i*8 + T.low.int - 1).some
  return none(T)

proc highestElement[T](s: set[T]): Option[T] {.raises: [].} =
  when sizeof(s) <= 8:
    highestSingleImpl(s)
  else:
    highestMultipleImpl(s)
```
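The index arithmetic in the AVX2 loop can be checked without any SIMD. The sketch below emulates `not mm256_movemask_epi8(mm256_cmpeq_epi8(vec, zero))` with a plain loop over one 32-byte chunk. `nonzeroByteMask` and `highestBitInChunk` are hypothetical helpers, and this is a scalar stand-in for checking the math, not what the intrinsics compile to:

```nim
import std/bitops

proc nonzeroByteMask(chunk: array[32, uint8]): uint32 =
  ## Scalar stand-in for the cmpeq/movemask/not sequence:
  ## bit k is set iff chunk[k] != 0.
  for k in 0 ..< 32:
    if chunk[k] != 0:
      result = result or (1'u32 shl k)

proc highestBitInChunk(chunk: array[32, uint8]): int =
  ## Bit offset (0..255) of the highest set bit in the chunk, or -1 if none.
  let mask = nonzeroByteMask(chunk)
  if mask == 0:
    return -1
  let idx = 31 - countLeadingZeroBits(mask)             # highest nonzero byte
  idx * 8 + (8 - countLeadingZeroBits(chunk[idx]) - 1)  # highest bit in it

var chunk: array[32, uint8]
chunk[5] = 0b0100
chunk[20] = 1
doAssert highestBitInChunk(chunk) == 160   # byte 20, bit 0
```

In `highestMultipleImpl` the element is then this bit offset plus `i*32*8 + 8 * (sizeof(s) mod 32) + T.low.int`.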
I compared that to this naive version:
```nim
proc naiveHighest[T](s: set[T]): Option[T] =
  # test every possible value, from the top down
  var i = T.high.int
  while i >= T.low.int:
    if T(i) in s:
      return T(i).some
    dec i
  return none(T)
```
with a few different versions of this benchmark:
```nim
import std/[options, random]
import benchy  # provides timeIt and keep

block:
  type X = range[8..int16.high]
  doAssert sizeof(set[X]) mod 32 == 31
  var x: set[X]
  for _ in 0..100:
    x.incl(rand(X))
  timeIt "not evenly vectorizable":
    var h: X
    for _ in 0..1000:
      x.incl(rand(X))
      h = x.highestElement.unsafeGet
      x.excl(h)
      keep(h)
```
and I got:
| min time | avg time | std dev | runs | name |
|---:|---:|---:|---:|:---|
| 0.061 ms | 0.069 ms | ±0.004 | x1000 | mod 32 == 0, vectorized |
| 18.571 ms | 22.570 ms | ±1.302 | x219 | mod 32 == 0, naive |
| 0.068 ms | 0.072 ms | ±0.002 | x1000 | mod 32 == 31, vectorized |
| 18.892 ms | 22.742 ms | ±0.485 | x221 | mod 32 == 31, naive |
| 0.791 ms | 0.836 ms | ±0.016 | x1000 | mod 32 == 31, clz only |
So, there you go.