I have defined a function:

```python
import numpy as np

def enumerateSpin(n):
    s = []
    for a in range(0, 3**n):
        ternary_rep = np.base_repr(a, 3)
        k = len(ternary_rep)
        r = (n - k)*'0' + ternary_rep
        if sum(map(int, r)) == n:
            s.append(r)
    return s
```
where I look at each number 0 <= a < 3^n and ask whether the digits of its ternary representation sum up to n. I do this by first converting the number into a string holding its ternary representation. I pad with zeros on the left because I want to store a list of fixed-length representations that I can later use for further computations (e.g. digit-by-digit comparison between two elements).
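One way to get the same fixed-length, left-zero-padded digits without building a string at all is to peel off base-3 digits with `divmod`. This is a minimal sketch, not something I have timed against `np.base_repr`; the helper names `ternary_digits` and `enumerate_spin_divmod` are hypothetical, and it returns tuples of integers (as Edit2 below suggests) rather than strings:

```python
def ternary_digits(a, n):
    """Return the base-3 digits of a as a length-n tuple, most significant first.

    Unused leading positions stay 0, which matches left-padding
    np.base_repr(a, 3) with zeros. Assumes 0 <= a < 3**n.
    """
    digits = [0] * n
    # Fill from the least significant position towards the front
    for pos in range(n - 1, -1, -1):
        a, digits[pos] = divmod(a, 3)
    return tuple(digits)


def enumerate_spin_divmod(n):
    # Same filter as enumerateSpin, applied to integer digit tuples
    return [d for d in (ternary_digits(a, n) for a in range(3**n)) if sum(d) == n]
```

For example, `ternary_digits(5, 3)` gives `(0, 1, 2)`, the tuple counterpart of the padded string `'012'`.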
Right now `np.base_repr` and `sum(map(int, r))` each take roughly 5 µs on my computer, meaning roughly 10 µs per iteration, and I am looking for an approach that accomplishes the same thing about 10 times faster.
(Edit: note about padding zeros on the left)
(Edit2: in hindsight, it is better to have the final representation be tuples of integers than strings).
(Edit3: for those wondering, the purpose of the code was to enumerate the states of a spin-1 chain that share the same total S_z value.)
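To connect the digit-sum condition to S_z: assuming the common convention that digit 0, 1, 2 stands for m = -1, 0, +1 on a site, the total S_z of a chain of n sites is `sum(digits) - n`, so the condition `sum(digits) == n` selects exactly the S_z = 0 sector. A small sketch under that assumption; `total_sz` and `sector` are hypothetical helpers, and `sector` simply filters `itertools.product`:

```python
import itertools as it


def total_sz(state):
    # Assumed convention: digit d on a site means m = d - 1 (0 -> -1, 1 -> 0, 2 -> +1)
    return sum(d - 1 for d in state)


def sector(n, m_total):
    # All n-site states (as digit tuples) whose total S_z equals m_total,
    # i.e. whose digits sum to n + m_total
    return [s for s in it.product((0, 1, 2), repeat=n) if total_sz(s) == m_total]
```

With this convention, `sector(n, 0)` is the same set of states that the digit-sum test `sum(digits) == n` picks out.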
You can use `itertools.product` to generate the digits and then convert them to the string representation:
```python
import itertools as it

def new(n):
    s = []
    for digits in it.product((0, 1, 2), repeat=n):
        if sum(digits) == n:
            s.append(''.join(str(x) for x in digits))
    return s
```
This gives me about a 7x speedup:

```
In [8]: %timeit enumerateSpin(12)
2.39 s ± 7.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [9]: %timeit new(12)
347 ms ± 4.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
Tested on Python 3.9.0 (IPython 7.20.0) (Linux).
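Since Edit2 says tuples of integers are the preferred final representation, the string conversion in `new` can simply be dropped, because `itertools.product` already yields tuples of ints. A minimal sketch (the name `new_tuples` is hypothetical, and I have not timed it separately):

```python
import itertools as it


def new_tuples(n):
    # Keep the digit tuples produced by itertools.product as they are,
    # skipping the ''.join(str(x) ...) conversion used in `new`
    return [digits for digits in it.product((0, 1, 2), repeat=n) if sum(digits) == n]
```

Digit-by-digit comparisons then work directly on the tuples, e.g. `sum(x != y for x, y in zip(a, b))` counts the sites where two states differ.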
The above procedure, using `it.product`, also generates numbers that we know by reasoning cannot satisfy the condition (this is the case for most numbers, since the digit sum must equal exactly the number of digits). For `n` digits, the counts of 2s, 1s and 0s must add up to `n`, and the digit sum `2*(#2s) + 1*(#1s)` must also equal `n`, which forces the number of 0s to equal the number of 2s. So a valid number has `i` twos, `n - 2*i` ones and `i` zeros for some `i` from 0 to `n//2`. We can then generate all distinct permutations of each of these digit multisets and thus only generate relevant numbers:
```python
import itertools as it
from more_itertools import distinct_permutations

def new2(n):
    all_digits = (('2',)*i + ('1',)*(n-2*i) + ('0',)*i for i in range(n//2+1))
    all_digits = it.chain.from_iterable(distinct_permutations(d) for d in all_digits)
    return (''.join(digits) for digits in all_digits)
```
Especially for large `n`, this gives an additional, significant speedup:

```
In [44]: %timeit -r 1 -n 1 new(16)
31.4 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

In [45]: %timeit -r 1 -n 1 list(new2(16))
7.82 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
```
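Note that the above solutions `new` and `new2` can have O(1) memory scaling with respect to the number of results when consumed lazily: `new2` already returns a generator, and `new` can be made into one by using `yield` instead of `append`.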
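A minimal sketch of that `yield` change (the name `new_lazy` is hypothetical): each matching string is produced on demand, so you can stop early or stream the results without holding the full list in memory.

```python
import itertools as it


def new_lazy(n):
    # Generator version of `new`: yields each matching string on demand
    # instead of appending it to a list
    for digits in it.product((0, 1, 2), repeat=n):
        if sum(digits) == n:
            yield ''.join(str(x) for x in digits)
```

For example, `next(new_lazy(3))` returns `'012'`, and `list(it.islice(new_lazy(16), 10))` takes just the first ten states without materializing the rest.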