# Randomly selecting from weighted samples

Recently, at an interview, I was given a small coding question: write a function that returns values with differing probabilities. The setup was this:

```
data = {'A': 8, 'B': 2, 'C': 5}
def random_sample(data):
```

Depending on the actual dictionary, more than one solution is possible. The option I ultimately ended up with was:

```
import random

def random_sample(data):  # Version 0
    data = {k: v / sum(data.values()) for k, v in data.items()}
    cdf = 0.
    rv = random.random()
    for k, v in data.items():
        cdf += v
        if rv < cdf:
            return k
```

I also mentioned that it would be faster to not call `sum(data.values())` so many times. This would lead to, e.g.,

```
import random

def random_sample(data):  # Version 1
    cdf = 0.
    rv = random.random()
    n = sum(data.values())
    for k, v in data.items():
        cdf += v / n
        if rv < cdf:
            return k
```

The interviewer said that their follow-up question, had we had time, would be
whether the less-than should be strict or not. Based on this, several things
come to mind. One is that calculations with `float` and comparisons with `<`
vs `<=` are typically going to be limited by floating point precision errors,
so the notion that we can have strictly less than is itself not exactly
realistic. Second is that it's possible in the above code that `cdf` never
reaches `1.0` at all, and there would be a correspondingly small chance that
`rv < cdf` is always `False`, in which case this function would return `None`.
Upon reflection this is rather interesting, as the interviewer said this was
the exact correct answer, yet clearly it's potentially bugged. The fix is at
least trivial: add `return k` after the `for` loop, thanks to Python's scope
rules (the loop variable stays bound after the loop finishes).
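Concretely, the patched version would read (the same Version 1 from above, with the fallback `return` added):

```
import random

def random_sample(data):  # Version 1, patched
    cdf = 0.
    rv = random.random()
    n = sum(data.values())
    for k, v in data.items():
        cdf += v / n
        if rv < cdf:
            return k
    return k  # fallback: k is still bound to the last key here
```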
Third is that if we consider integers, then we could have solved this without using
probabilities at all:

```
import random

def random_sample(data):  # Version 2
    n = sum(data.values())
    rv = random.randint(0, n - 1)
    cdf = 0
    for k, v in data.items():
        cdf += v
        if rv < cdf:
            return k
```

Here, rounding error is impossible, so there's no need for a `return k` at the
end: this function will always return a key. (With an empty dictionary it
doesn't return `None` either; `random.randint(0, -1)` raises a `ValueError`.)
Thinking about the test data, however, suppose that `'A'` is the first key; on
the first iteration we then have `cdf == 8`, meaning we want to return `'A'`
for `8` possible values of `rv`. Since `rv` is chosen from the range `[0, n)`,
this corresponds to the values `[0, 7]`, and thus we do want strictly less
than.
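One way to convince ourselves of the counting argument is to enumerate every value `rv` can take rather than drawing it at random. The sketch below factors Version 2's loop into a helper (`pick` is a name introduced here for illustration):

```
from collections import Counter

data = {'A': 8, 'B': 2, 'C': 5}
n = sum(data.values())  # 15

def pick(data, rv):
    # Version 2's loop body, with rv supplied explicitly
    cdf = 0
    for k, v in data.items():
        cdf += v
        if rv < cdf:
            return k

# Enumerate every possible rv in [0, n) and count the outcomes
counts = Counter(pick(data, rv) for rv in range(n))
print(counts)  # Counter({'A': 8, 'C': 5, 'B': 2})
```

With strict less-than, each key wins for exactly as many values of `rv` as its weight, which is the whole point.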

Afterwards, a one-line solution came to mind which, at the cost of memory,
completely sidesteps both the potential for rounding error to produce a return
value of `None` and the question of strictly less than. Instead, we can
exploit `random.choice` and a list comprehension:

```
from random import choice

def random_sample(data):  # Version 3
    return choice([k for k, v in data.items() for _ in range(v)])
```

Going a step further we can even enlist some help from the `collections` module:

```
from random import choice
from collections import Counter

def random_sample(data):  # Version 4
    return choice(list(Counter(data).elements()))
```

In my opinion, this code is the most expressive, though we have to acknowledge
the memory (and therefore computational) cost here. The memory usage scales
linearly with `sum(data.values())`. The method `collections.Counter.elements`
actually returns an `itertools.chain`, which cannot be used in the `len`
expression that `random.choice` needs, explaining the conversion to a `list`.
This technically results in potentially unbounded memory usage (e.g.,
`{'A': 10**9}` would produce a `list` that consumes many gigabytes of memory).
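The linear scaling is easy to see directly, since the expanded list always holds exactly `sum(data.values())` elements:

```
from collections import Counter

data = {'A': 8, 'B': 2, 'C': 5}
expanded = list(Counter(data).elements())
print(len(expanded))          # 15, i.e. sum(data.values())
print(sorted(set(expanded)))  # ['A', 'B', 'C']
```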

Now for a reality check: performance. I tested this with 100k iterations for
each version and the dictionary `{'A': 10, 'B': 3, 'C': 4}` on Python 3.5.2.
The results are:

Version | Time (s) |
---|---|
0 | 0.366 |
1 | 0.157 |
2 | 0.450 |
3 | 0.621 |
4 | 1.416 |

In version 0, it's either the repeated sum or the dictionary comprehension
that is causing the slower performance relative to version 1. Moving the
summation out but retaining the dictionary comprehension gives 0.268 s,
meaning roughly half of the performance difference comes from the repeated
calls to `sum` and half from the comprehension itself. Timing
`random.random()` versus `random.randint(0, 16)` reveals a large difference
between generating random floating point numbers and random integers, which
accounts for the decrease in performance seen in version 2. Version 3 is a
little trickier to explain, but there is an overhead in creating the `list`
from the comprehension. Timing the list comprehension in version 3 against
the call to `list(Counter(data).elements())`, I find that the former takes
0.40 s while the latter takes 1.16 s.

Cranking up the number of items with `{'A': 350, 'B': 30, 'C': 140}`, I found
that versions 3 and 4 start to greatly underperform the other solutions while
also switching their relative ordering:

Version | Time (s) |
---|---|
0 | 0.284 |
1 | 0.158 |
2 | 0.465 |
3 | 2.548 |
4 | 2.320 |

Do note that here version 0 has been updated to move the summation out, so the
comparable entry from the previous table is the 0.268 s discussed there,
indicating that all solutions are at least somewhat dependent upon the number
of items in the collection, with the issue clearly affecting versions 3 and 4
the most. Adding more keys to the input dictionary while preserving the total
should impact only version 0, as it needs to store more keys. Versions 3 and 4
wouldn't be affected by this since they depend only upon `sum(data.values())`,
and versions 1 and 2 clearly don't create any new containers to begin with.
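For completeness, the timings above can be reproduced with a small `timeit` harness along these lines (a sketch, not the exact script used; absolute numbers will vary by machine and Python version):

```
import timeit
from random import random, choice
from collections import Counter

data = {'A': 10, 'B': 3, 'C': 4}

def version_1(data):
    # Version 1 with the fallback return discussed earlier
    cdf = 0.
    rv = random()
    n = sum(data.values())
    for k, v in data.items():
        cdf += v / n
        if rv < cdf:
            return k
    return k

def version_4(data):
    return choice(list(Counter(data).elements()))

for fn in (version_1, version_4):
    t = timeit.timeit(lambda: fn(data), number=100000)
    print('{}: {:.3f}s'.format(fn.__name__, t))
```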