
I have defined a function

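import numpy as np

def enumerateSpin(n):
    s = []
    for a in range(0,3**n):
        ternary_rep = np.base_repr(a,3)   # ternary string, no leading zeros
        k = len(ternary_rep)
        r = (n-k)*'0'+ternary_rep         # pad on the left to exactly n digits
        if sum(map(int,r)) == n:          # keep it if the digits sum to n
            s.append(r)
    return s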

where I look at each number 0 <= a < 3^n and check whether the digits of its ternary representation sum to n. I do this by first converting the number into a string holding its ternary representation. I pad with zeros because I want to store a list of fixed-length representations that I can later use for further computations (e.g. digit-by-digit comparison between two elements).
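A minimal sketch of why the fixed-length padding matters for digit-by-digit comparison, assuming the states are stored as tuples of ints (as Edit2 suggests). The helper names `to_digits` and `differing_sites` are hypothetical, not from the question:

```python
def to_digits(rep):
    """Convert a fixed-length ternary string such as '0112' to a tuple of ints."""
    return tuple(int(ch) for ch in rep)

def differing_sites(a, b):
    """Positions where two equal-length states differ (hypothetical helper)."""
    if len(a) != len(b):
        # unpadded strings like '112' and '0121' would be misaligned by zip
        raise ValueError("states must have the same length; pad with leading zeros")
    return [i for i, (x, y) in enumerate(zip(a, b)) if x != y]

print(differing_sites(to_digits('0112'), to_digits('0121')))  # [2, 3]
```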

Right now np.base_repr and sum(map(int, r)) each take roughly 5 µs on my computer, so one iteration takes roughly 10 µs. I am looking for an approach that accomplishes the same thing about 10 times faster.

(Edit: note about padding zeros on the left)

(Edit2: in hindsight, it is better to have the final representation be tuples of integers than strings).

(Edit3: for those wondering, the purpose of the code was to enumerate states of a spin-1 chain that have the same total S_z values.)
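A sketch of how the digit-sum condition relates to S_z, assuming the common encoding where digit d on a site stands for m = d - 1 (so 0, 1, 2 mean S_z = -1, 0, +1; this encoding is an assumption, not stated in the question). Then sum(digits) == n is the same as total S_z = 0, and other sectors only change the target. The function name `sz_sector` is hypothetical:

```python
import itertools as it

def sz_sector(n, total_sz=0):
    """All length-n spin-1 chain states with the given total S_z.

    Assumes digit d encodes m = d - 1 (0 -> -1, 1 -> 0, 2 -> +1), so
    sum(m) == total_sz is equivalent to sum(digits) == n + total_sz.
    """
    target = n + total_sz
    return [digits for digits in it.product((0, 1, 2), repeat=n)
            if sum(digits) == target]

print(len(sz_sector(3)))         # 7, the same states enumerateSpin(3) returns
print(sz_sector(2, total_sz=1))  # [(1, 2), (2, 1)]
```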

wcc
  • First of all, why bother padding your ternary representation with `0`s? That string addition is pretty expensive. Secondly, why do you need something 10x faster? Third, why not just `sum(int(i) for i in np.base_repr(a,3))`. Fourth, how about some parallelization to improve code speed? – inspectorG4dget Feb 05 '21 at 21:37
  • What is your goal? Do you have to return the ternary representation as a string? Can't you just return a list of numbers? – fukanchik Feb 05 '21 at 21:40
  • I am padding zeros and separating the mapping/list comprehension from `np.base_repr` because I wanted to store fixed-length representations of the numbers in my list `s`. I just quoted a 10x improvement in speed because I felt that what I have done is pretty inefficient and surely something better exists. Parallelization is a good suggestion, but I do not have experience with it. – wcc Feb 05 '21 at 21:40

4 Answers


You can use itertools.product to generate the digits and then convert to the string representation:

import itertools as it

def new(n):
    s = []
    for digits in it.product((0, 1, 2), repeat=n):
        if sum(digits) == n:
            s.append(''.join(str(x) for x in digits))
    return s

This gives me about a 7x speedup:

In [8]: %timeit enumerateSpin(12)
2.39 s ± 7.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [9]: %timeit new(12)
347 ms ± 4.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Tested on Python 3.9.0 (IPython 7.20.0) (Linux).

The above procedure, using it.product, also generates many digit tuples that we know by reasoning cannot satisfy the condition; in fact only a minority of all 3**n tuples have a digit sum of exactly n. For n digits, we can instead compute the counts of digits 2, 1 and 0 that sum up to n: since every 2 must be balanced by a 0, these are i twos, i zeros and n-2i ones for i = 0, ..., n//2. Then we can generate all distinct permutations of these digits and thus only generate relevant numbers:

import itertools as it
from more_itertools import distinct_permutations

def new2(n):
    all_digits = (('2',)*i + ('1',)*(n-2*i) + ('0',)*i for i in range(n//2+1))
    all_digits = it.chain.from_iterable(distinct_permutations(d) for d in all_digits)
    return (''.join(digits) for digits in all_digits)

Especially for large n this gives an additional, significant speedup:

In [44]: %timeit -r 1 -n 1 new(16)
31.4 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

In [45]: %timeit -r 1 -n 1 list(new2(16))
7.82 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
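The digit counts used in new2 also give the size of the output without enumerating it: the group with i twos, i zeros and n-2i ones has n!/(i!·i!·(n-2i)!) distinct permutations. A small sketch (the name `count_states` is hypothetical) that can be used to check new2, or to estimate memory before running it:

```python
from math import comb
import itertools as it

def count_states(n):
    """Number of length-n ternary strings whose digits sum to n.

    Group i (i twos, i zeros, n - 2*i ones) contributes the multinomial
    n! / (i! * i! * (n - 2*i)!) = comb(n, i) * comb(n - i, i).
    """
    return sum(comb(n, i) * comb(n - i, i) for i in range(n // 2 + 1))

# cross-check against brute force for small n
for n in range(1, 8):
    brute = sum(1 for d in it.product((0, 1, 2), repeat=n) if sum(d) == n)
    assert count_states(n) == brute

print(count_states(12))  # 73789, out of 3**12 = 531441 candidates
```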

Note that both new and new2 can run with O(1) memory scaling: new2 already returns a lazy generator, and new only needs to yield each string instead of appending it to a list.
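The generator variant of new mentioned above could look like this (a sketch; the name `new_gen` is hypothetical). It yields one string at a time instead of building the list, so memory use does not grow with the number of results as long as the caller consumes them lazily:

```python
import itertools as it

def new_gen(n):
    """Lazily yield fixed-length ternary strings whose digits sum to n."""
    for digits in it.product((0, 1, 2), repeat=n):
        if sum(digits) == n:
            yield ''.join(map(str, digits))

# consume lazily, e.g. count the states without storing them
print(sum(1 for _ in new_gen(4)))  # 19
```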

a_guest
  • This is nice and elegant. However, I am getting a ~3x speedup, and I am wondering where the difference comes from. I get (0.827 vs 2.4) for N=12 and (2.57 vs 8.36) for N=13. I am using `'3.7.7 (default, May 6 2020, 11:45:54) [MSC v.1916 64 bit (AMD64)]'` by the way. – wcc Feb 05 '21 at 22:24
  • @wcc It turned out that I was accidentally running the code in a Python 2.7 IPython shell. Now I repeated it with Python 3.9 and I get about a 7x speedup. – a_guest Feb 05 '21 at 22:38
  • Thanks for the feedback. I got my numbers using the `%%timeit` magic in a Jupyter notebook. Does running in Jupyter generally make things slower? – wcc Feb 05 '21 at 22:40
  • @wcc It shouldn't be much slower in Jupyter, since all it does is send your query to the IPython kernel. Perhaps it's the Python version? By the way, I've added another solution that only generates relevant numbers, please see my updated answer. – a_guest Feb 05 '21 at 23:20

A 10x improvement can be achieved by delegating all calculations to numpy in order to leverage vectorized processing:

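import numpy as np

def eSpin(n):
    nums    = np.arange(3**n,dtype=np.int64)                   # all candidates 0 .. 3**n-1
    base3   = nums // (3**np.arange(n,dtype=np.int64))[:,None] % 3   # row i = digit for 3**i
    matches = np.sum(base3,axis=0) == n                        # columns whose digits sum to n
    digits  = np.sum(base3[:,matches] * 10**np.arange(n,dtype=np.int64)[:,None],axis=0)
    return [f"{a:0{n}}" for a in digits]                       # zero-padded strings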

How it works (example for eSpin(3)):

nums is an array of all numbers from 0 up to, but not including, 3**n:

   [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26]  

base3 converts it into base-3 digits along an additional dimension (row i holds the digit for 3**i, so the least significant digit comes first):

[[0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2]
 [0 0 0 1 1 1 2 2 2 0 0 0 1 1 1 2 2 2 0 0 0 1 1 1 2 2 2]
 [0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2]]

matches is a boolean mask (shown here as 0/1) that identifies the columns whose base-3 digits sum to n:

 [0 0 0 0 0 1 0 1 0 0 0 1 0 1 0 1 0 0 0 1 0 1 0 0 0 0 0]

digits turns each matching column into a base-10 number whose decimal digits are the base-3 digits:

 [ 12  21 102 111 120 201 210]

And finally the matching (base10) numbers are formatted with leading zeros.
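The key step is the broadcasting in base3: (3**np.arange(n))[:, None] has shape (n, 1), so integer-dividing the shape-(3**n,) array nums by it yields an (n, 3**n) array. A short sketch for n = 3 that checks the shapes and recovers the matching column indices shown above:

```python
import numpy as np

n = 3
nums = np.arange(3**n, dtype=np.int64)                # shape (27,)
powers = (3**np.arange(n, dtype=np.int64))[:, None]   # shape (3, 1): [[1], [3], [9]]
base3 = nums // powers % 3                            # broadcasts to shape (3, 27)
matches = base3.sum(axis=0) == n                      # boolean mask over the 27 columns

print(powers.shape, base3.shape)           # (3, 1) (3, 27)
print(np.flatnonzero(matches).tolist())    # [5, 7, 11, 13, 15, 19, 21]
```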

Performance:

from timeit import timeit
count = 1

print(enumerateSpin(10)==eSpin(10)) # True

t1 = timeit(lambda:eSpin(13),number=count)
print("eSpin",t1) # 0.634 sec

t0 = timeit(lambda:enumerateSpin(13),number=count)
print("enumerateSpin",t0) # 7.362 sec

Tuple version:

def eSpin2(n):
    nums    = np.arange(3**n,dtype=np.int64)
    base3   = nums// (3**np.arange(n,dtype=np.int64))[:,None]  % 3
    matches = np.sum(base3,axis=0) == n
    return [*map(tuple,base3[:,matches].T.tolist())]   # tolist() gives plain Python ints

eSpin2(3)
[(2, 1, 0), (1, 2, 0), (2, 0, 1), (1, 1, 1), (0, 2, 1), (1, 0, 2), (0, 1, 2)]
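Note that eSpin2 returns the digits least-significant first: (2, 1, 0) is the tuple for the string '012'. If the tuples should follow the same digit order as the strings, reversing the rows of base3 before transposing should do it. A sketch (the name `eSpin2_msd` is hypothetical):

```python
import numpy as np

def eSpin2_msd(n):
    nums = np.arange(3**n, dtype=np.int64)
    # row i holds the digit for 3**i (least significant first)
    base3 = nums // (3**np.arange(n, dtype=np.int64))[:, None] % 3
    matches = base3.sum(axis=0) == n
    # flip the rows so the most significant digit comes first, as in the strings
    return [tuple(col) for col in base3[::-1, matches].T.tolist()]

print(eSpin2_msd(3))
# [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 1, 1), (1, 2, 0), (2, 0, 1), (2, 1, 0)]
```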

[EDIT] An even faster approach (40x to 80x faster than enumerateSpin)

Using dynamic programming and memoization can provide much better performance:

from functools import lru_cache

@lru_cache()
def eSpin(n,base=3,target=None):
    if target is None: target = n
    if target == 0: return [(0,)*n]
    if target > n*(base-1): return []   # n digits can sum to at most n*(base-1)
    if n==1: return [(target,)]
    result = []
    for d in range(min(base,target+1)):   # choose the leading digit, recurse on the rest
        result.extend((d,)+suffix for suffix in eSpin(n-1,base,target-d) )
    return result

t4 = timeit(lambda:eSpin(13),number=count)
print("eSpin",t4) # 0.108 sec

eSpin.cache_clear()
t5 = timeit(lambda:eSpin(16),number=count)
print("eSpin",t5) # 2.25 sec
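The same recursion can count the states instead of listing them, which is handy for checking results or sizing memory before enumerating. A sketch under the same memoization idea (`count_spin` is a hypothetical name), pruning with the digit-sum bound n*(base-1):

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def count_spin(n, base=3, target=None):
    """Count length-n base-`base` digit tuples whose digits sum to target (default n)."""
    if target is None:
        target = n
    if target < 0 or target > n * (base - 1):
        return 0          # no choice of n digits can reach this sum
    if n == 0:
        return 1          # only the empty tuple remains, and target is 0 here
    return sum(count_spin(n - 1, base, target - d)
               for d in range(min(base, target + 1)))

print(count_spin(4))  # 19
```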
Alain T.
  • thanks. This is slightly faster than the answer by @a_guest. But I still favor the latter in that I can also easily choose to save the representations as tuples of integers instead of strings. – wcc Feb 05 '21 at 23:05
  • You could do `return [*map(tuple,base3[:,matches].T)]` instead of converting to `digits` if you want tuples, but that's up to you. – Alain T. Feb 05 '21 at 23:10
  • If I'm correct, `base3` has `n*3**n` elements? For larger numbers of `n` this will quickly consume huge amounts of memory. – a_guest Feb 05 '21 at 23:23
  • @a_guest, it is indeed. I tested up to eSpin(16) without memory issues; the processing time is the main limiting factor. The generator approach is way better in that respect. The small performance gain with numpy will probably not be worth the resource consumption, so I also prefer your solution. – Alain T. Feb 05 '21 at 23:25
  • the improved method is pretty impressive in terms of speed. Thanks! – wcc Feb 06 '21 at 02:42

Here's a multiprocessing approach. The larger the problem, the more time it should save:

import multiprocessing as mp
import numpy as np

# With the "spawn" start method (Windows, macOS), call enumerateSpin from a script
# under an `if __name__ == "__main__":` guard.


def filter_chunks(n, qIn, qOut):  # this is the function that will be parallelized
    nums = range(3**n)
    answer = []
    for low,high in iter(qIn.get, None):
        for num in nums[low:high]:
            r = np.base_repr(num, 3)  # ternary representation
            if sum(int(i) for i in r) == n:  # this is your check: digit sum equals n
                answer.append('0'*(n-len(r)) +r)  # turn it into a fixed length
    qOut.put(answer)
    qOut.put(None)


def enumerateSpin(n):  # this is the primary entry point
    numProcs = max(1, mp.cpu_count()-1)  # fiddle to taste
    chunkSize = max(1, 3**n//numProcs)  # split the 3**n numbers, not the n digits

    qIn, qOut = [mp.Queue() for _ in range(2)]
    procs = [mp.Process(target=filter_chunks, args=(n, qIn, qOut)) for _ in range(numProcs)]

    for p in procs: p.start()
    for i in range(0, 3**n, chunkSize):  # chunkify your numbers so that IPC is more efficient
        qIn.put((i, i+chunkSize))
    for p in procs: qIn.put(None)

    answer = []
    done = 0
    while done < len(procs):
        t = qOut.get()
        if t is None:
            done += 1
            continue
        answer.extend(t)

    for p in procs: p.join()  # every worker has sent its sentinel, so each exits cleanly

    return sorted(answer)  # workers finish in arbitrary order
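A simpler sketch of the same chunking idea using multiprocessing.Pool, which takes care of the queues, sentinels and process shutdown itself. This is an alternative, not the answer author's code; the names `ternary_chunk`, `enumerate_spin_pool` and the `chunks_per_proc` parameter are hypothetical:

```python
import multiprocessing as mp
from functools import partial


def ternary_chunk(n, bounds):
    """Fixed-length ternary strings for numbers in [low, high) whose digits sum to n."""
    low, high = bounds
    out = []
    for num in range(low, high):
        x, digits = num, []
        for _ in range(n):               # exactly n digits, so padding is automatic
            x, d = divmod(x, 3)
            digits.append(d)             # least significant digit first
        if sum(digits) == n:
            out.append(''.join(map(str, reversed(digits))))
    return out


def enumerate_spin_pool(n, processes=None, chunks_per_proc=4):
    total = 3**n
    processes = processes or mp.cpu_count()
    step = max(1, total // (processes * chunks_per_proc))
    bounds = [(lo, min(lo + step, total)) for lo in range(0, total, step)]
    with mp.Pool(processes) as pool:
        # Pool.map returns results in the order of `bounds`, so the output stays sorted
        parts = pool.map(partial(ternary_chunk, n), bounds)
    return [s for part in parts for s in part]
```

On platforms that use the spawn start method (Windows, and macOS by default), call enumerate_spin_pool from a script under an `if __name__ == "__main__":` guard; functions defined interactively, e.g. in a Jupyter notebook, generally cannot be loaded by spawned workers, which may be related to the hang reported in the comments below.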
inspectorG4dget
  • Thanks. I assume this has to be run outside of a Jupyter notebook (somehow the function does not terminate and I have to kill the kernel)? – wcc Feb 05 '21 at 22:49
  • @wcc: oops! I forgot to kill the processes. Please check the update :) – inspectorG4dget Feb 06 '21 at 21:23

In general, to get the digits of a number in a specific base we can do:

num, base = 14, 3
while num > 0:
    digit = num % base   # least significant digit first
    num //= base
    print(digit)

When running this with num = 14 and base = 3, we get:

2
1
1

Read from last to first, this means that 14 in ternary is 112 (1·9 + 1·3 + 2).
We can extract that into a function digits(num, base) and only use np.base_repr(a,3) when we actually need to convert the number into a string:

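import numpy as np

def digits(num, base):
    """Return the digits of num in the given base, least significant first."""
    result = []
    while num > 0:
        result.append(num % base)
        num //= base
    return result

def enumerateSpin(n):
    s = []
    for a in range(0,3**n):
        if sum(digits(a, 3)) == n:  # cheap integer check before any string work
            ternary_rep = np.base_repr(a,3)
            r = (n-len(ternary_rep))*'0'+ternary_rep
            s.append(r)
    return s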

Output for enumerateSpin(4):

['0022', '0112', '0121', '0202', '0211', '0220', '1012', '1021', '1102', '1111', '1120', '1201', '1210', '2002', '2011', '2020', '2101', '2110', '2200']
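Since digits(num, base) already produces every digit, the string could be built from them directly instead of calling np.base_repr a second time. A sketch assuming a fixed-width variant that always returns exactly n digits, most significant first (`digits_fixed` and `enumerate_spin_digits` are hypothetical names):

```python
def digits_fixed(num, base, width):
    """Return exactly `width` digits of num in `base`, most significant first."""
    out = [0] * width
    for i in range(width - 1, -1, -1):   # fill from the least significant end
        num, out[i] = divmod(num, base)
    return out

def enumerate_spin_digits(n):
    s = []
    for a in range(3**n):
        d = digits_fixed(a, 3, n)        # zero padding comes for free
        if sum(d) == n:
            s.append(''.join(map(str, d)))
    return s

print(digits_fixed(14, 3, 4))  # [0, 1, 1, 2]
```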
Roy Cohen