I am seeking to sample n random permutations of a list in Python.
This is my code:
import random
from itertools import permutations

obj = [ 5 8 9 ... 45718 45719 45720]
#type(obj) = numpy.ndarray
pairs = random.sample(list(permutations(obj, 2)), k=150)
Although the code does what I want, it causes memory issues. I sometimes get a MemoryError
when running on CPU, and when running on GPU, my virtual machine crashes.
How can I make the code work in a more memory-efficient manner?
Building on Pablo Ruiz's excellent answer, I suggest wrapping his sampling solution into a generator function that yields unique permutations by keeping track of what it has already yielded:
import numpy as np

def unique_permutations(sequence, r, n):
    """Yield n unique permutations of r elements from sequence"""
    seen = set()
    while len(seen) < n:
        # This line of code adapted from Pablo Ruiz's answer:
        candidate_permutation = tuple(np.random.choice(sequence, r, replace=False))
        if candidate_permutation not in seen:
            seen.add(candidate_permutation)
            yield candidate_permutation

obj = list(range(10))

for permutation in unique_permutations(obj, 2, 15):
    print(permutation)  # do something with the permutation

# Or, to save the result as a list:
pairs = list(unique_permutations(obj, 2, 15))
My assumption is that you are sampling a small subset of the very large number of possible permutations, in which case collisions will be rare enough that keeping a seen set will not be expensive.
Warnings: this function loops forever if you ask for more permutations than are possible given the inputs. It will also get increasingly slow as n gets close to the number of possible permutations, since collisions will become increasingly frequent.
If I were to put this function in my code base, I would put a guard at the top that calculates the number of possible permutations and raises a ValueError if n exceeds that number, and perhaps emits a warning if n exceeds one tenth of that number, or something like that.
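The guard described above might look something like the sketch below. The names checked_unique_permutations and _generate are made up for illustration, the one-tenth threshold is just the arbitrary figure mentioned above, and the elements of sequence are assumed to be distinct. The sketch swaps in the standard library's random.sample for np.random.choice so that it has no third-party dependency (convert a NumPy array with .tolist() first), and math.perm needs Python 3.8 or later. The check sits in a plain wrapper function rather than in the generator itself, because code inside a generator body does not run until the first item is requested.

```python
import math
import random
import warnings


def _generate(sequence, r, n):
    """Yield n unique r-element permutations by rejection sampling."""
    seen = set()
    while len(seen) < n:
        # random.sample draws r distinct positions, so no element repeats
        candidate = tuple(random.sample(sequence, r))
        if candidate not in seen:
            seen.add(candidate)
            yield candidate


def checked_unique_permutations(sequence, r, n):
    """Validate n against the number of possible permutations, then sample.

    Assumes the elements of sequence are distinct; with duplicates the
    true number of distinct permutations is smaller than math.perm reports.
    """
    total = math.perm(len(sequence), r)  # len! / (len - r)!, or 0 if r > len
    if n > total:
        raise ValueError(
            f"requested {n} permutations, but only {total} are possible"
        )
    if n > total / 10:  # arbitrary threshold, tune to taste
        warnings.warn(
            f"requesting {n} of {total} possible permutations; "
            "collisions will make sampling slow"
        )
    return _generate(sequence, r, n)


pairs = list(checked_unique_permutations(list(range(10)), 2, 5))
```

Because the validation happens in the wrapper, a bad n fails at the call site instead of the first time the result is iterated.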