I was going through the BERT repo and found the following piece of code:

for _ in range(10):
    random_document_index = rng.randint(0, len(all_documents) - 1)
    if random_document_index != document_index:
        break

The idea is to generate a random integer in [0, len(all_documents) - 1] that is not equal to document_index. Because len(all_documents) is supposed to be very large, the first iteration is almost guaranteed to produce a valid value, but just to be safe, they try up to 10 times. I can't help but think there has to be a better way to do this.

I found this answer, which is easy enough to implement in Python:

random_document_index = rng.randint(0, len(all_documents) - 2)
random_document_index += 1 if random_document_index >= document_index else 0

I was just wondering if there's a better way to achieve this in Python using the built-in functions (or even with NumPy), or if this is the best you can do.

Jay Mody
  • Looks like the solution is not working if you need to draw from `[0, 1, 3, 4, 6, 7]`. Is this correct and a situation you would be interested in? – norok2 Dec 25 '19 at 20:52
  • @norok2 No, in this case I'm only looking to exclude one number, so this wouldn't be a case I'd be interested in. – Jay Mody Dec 25 '19 at 21:00

2 Answers


Had len(all_documents) been small, a pretty solution would be to materialize all valid numbers (e.g. in a list) and use random.choice(). Since your len(all_documents) is supposedly large, this solution would waste a lot of memory.

A more memory efficient solution is to stick with the original strategy. It's really very reasonable for large len(all_documents) where a single iteration is very likely to be enough, though the hard-coded 10 is ugly. A pretty one-line solution would be to make use of the new walrus operator in Python 3.8:

while (random_document_index := rng.randint(0, len(all_documents) - 1)) == document_index: pass
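For Python versions before 3.8, or if you prefer something more explicit, the same rejection loop can be wrapped in a small function. This is only a sketch: the name randint_excluding_by_rejection and the n < 2 guard are my own additions, and it assumes the excluded value lies in range(n).

```python
import random


def randint_excluding_by_rejection(rng, n, excluded):
    """Draw uniformly from range(n), redrawing whenever `excluded` comes up."""
    if n < 2:
        # With a single candidate the loop below would never terminate.
        raise ValueError("need at least two values to exclude one")
    while True:
        candidate = rng.randint(0, n - 1)  # randint includes both endpoints
        if candidate != excluded:
            return candidate


rng = random.Random(12345)  # seed chosen arbitrarily for the demo
draws = [randint_excluding_by_rejection(rng, 5, 2) for _ in range(1000)]
print(2 in draws)
# False
```

Each pass succeeds with probability (n - 1) / n, so for large n a second draw is rarely needed.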
jmd_dk

Perhaps a more elegant way of picking integers with holes is to use random.choice():

import random


seq = [0, 1, 3, 4, 6, 7]
random.choice(seq)

The drawback is that it requires a sequence. A plain list may not be efficient in your case, and it is generally wasteful when the range is much larger than the number of invalid values. In that case, a more efficient approach is to create a custom sequence that stores only the "holes".


EDIT

Such an implementation would take the form of a non-contiguous range (without step support) with invalid numbers, implementing the Sequence interface:

class NonContRange(object):
    def __init__(self, start, stop, invalid=None):
        self.start = start
        self.stop = stop
        self.invalid = invalid if invalid else set()

    def __len__(self):
        return self.stop - self.start - len(self.invalid)

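    def __getitem__(self, i):
        if i < 0:
            i += len(self)  # support negative indices
        if not 0 <= i < len(self):
            raise IndexError('NonContRange index out of range')
        offset = 0
        # Visit the holes in ascending order and step past each one reached so far.
        for invalid in sorted(self.invalid):
            if invalid <= self.start + i + offset:
                offset += 1
        return self.start + i + offset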

    def __iter__(self):
        for i in range(self.start, self.stop):
            if i not in self.invalid:
                yield i

    def __reversed__(self):
        for i in range(self.stop - 1, self.start - 1, -1):
            if i not in self.invalid:
                yield i

    def index(self, x):
        # Values outside [start, stop) are not in the sequence either.
        if x in self.invalid or not (self.start <= x < self.stop):
            raise ValueError(f'{x} not in sequence.')
        offset = sum(1 for y in self.invalid if y < x)
        return x - self.start - offset

    def count(self, x):
        return 0 if x in self.invalid or not (self.start <= x < self.stop) else 1

    def __str__(self):
        return f'NonContRange({self.start}, {self.stop}, {sorted(self.invalid)})'

A few tests:

seq = NonContRange(10, 20, {12, 15, 16})
print(seq)
# NonContRange(10, 20, [12, 15, 16])
print(list(seq))
# [10, 11, 13, 14, 17, 18, 19]
print(list(reversed(seq)))
# [19, 18, 17, 14, 13, 11, 10]
print([seq[i] for i in range(len(seq))])
# [10, 11, 13, 14, 17, 18, 19]
print(seq.count(19))
# 1
print(seq.count(12))
# 0

and this can be safely used with random.choice():

import random


invalid = {12, 17}
seq = NonContRange(10, 20, invalid)
print(all(random.choice(seq) not in invalid for _ in range(10000)))
# True
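If there are many holes, sorting them on every lookup becomes the bottleneck. One possible refinement, sketched here under my own assumptions (the class name FastNonContRange and its private attributes are hypothetical, and holes outside the range are silently ignored), is to precompute for each hole how many valid values precede it and then binary-search that table with bisect:

```python
import bisect
import random


class FastNonContRange:
    """Hypothetical variant of NonContRange with O(log k) indexing."""

    def __init__(self, start, stop, invalid=()):
        self.start = start
        self.stop = stop
        # Keep only holes that actually fall inside [start, stop).
        holes = sorted(h for h in set(invalid) if start <= h < stop)
        self._n_holes = len(holes)
        # keys[j] = number of valid values smaller than the j-th hole.
        self._keys = [h - start - j for j, h in enumerate(holes)]

    def __len__(self):
        return self.stop - self.start - self._n_holes

    def __getitem__(self, i):
        if i < 0:
            i += len(self)
        if not 0 <= i < len(self):
            raise IndexError("index out of range")
        # Count the holes that precede the i-th valid value, then skip them.
        skipped = bisect.bisect_right(self._keys, i)
        return self.start + i + skipped


seq = FastNonContRange(10, 20, {12, 15, 16})
print([seq[i] for i in range(len(seq))])
# [10, 11, 13, 14, 17, 18, 19]
print(random.choice(seq) in {12, 15, 16})
# False
```

Only __len__ and __getitem__ are shown because that is all random.choice() needs; the remaining Sequence methods are left out of this sketch.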

This is of course very nice in the general case, but for your specific situation it looks more like killing a fly with a cannonball.
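For the single-exclusion case, the shift trick from the question is probably all you need. A minimal sketch as a reusable function, where the name randint_excluding and the argument checks are my own assumptions rather than anything from the BERT repo:

```python
import random
from collections import Counter


def randint_excluding(rng, n, excluded):
    """Draw uniformly from range(n) minus `excluded` with a single randint call."""
    if not 0 <= excluded < n or n < 2:
        raise ValueError("excluded must be in range(n) and n must be at least 2")
    # Draw from the n - 1 remaining slots, then shift past the excluded value.
    candidate = rng.randint(0, n - 2)
    return candidate + 1 if candidate >= excluded else candidate


rng = random.Random(0)  # arbitrary seed, only to make the demo repeatable
counts = Counter(randint_excluding(rng, 5, 2) for _ in range(10000))
print(sorted(counts))
# [0, 1, 3, 4]
```

Unlike the retry loop, this always uses exactly one random draw, and every remaining value is equally likely.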

norok2
  • In this case, the piece of code is nested in a loop, where `document_index` is a changing value. Removing it from a list, performing random.choice, then re-adding it to the list, would be less efficient. – Jay Mody Dec 25 '19 at 21:02
  • @JayMody you do not need a `list`, you just need a sequence. See the edits for an efficient non-contiguous range that would also work well in your use case, although it may be overkill for the problem at hand. – norok2 Dec 25 '19 at 21:37