A Few Observations About the Fisher-Yates Shuffle Algorithm

Shuffling an array into random order is a common task in machine learning algorithms. For example, if an array holds array-index values such as idxs = [0, 1, 2, 3, ...] then shuffling the idxs array gives you a way to iterate through training data in random order.
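
As a minimal sketch of that idea, the snippet below shuffles a list of index values and uses it to visit some made-up training items in random order. The train_x data, the seed value, and the use of the standard-library random.Random in place of NumPy are all assumptions made for illustration.

```python
import random

# hypothetical training data, for illustration only
train_x = ["a", "b", "c", "d", "e"]

rnd = random.Random(0)             # local generator with an arbitrary seed
idxs = list(range(len(train_x)))   # [0, 1, 2, 3, 4]
rnd.shuffle(idxs)                  # in-place shuffle of the index values

visited = []
for i in idxs:                     # visit items in shuffled order
  visited.append(train_x[i])
print(idxs)
print(visited)
```

Every item is visited exactly once per pass, just in a different order each time the index list is reshuffled.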

The NumPy library for Python has a built-in numpy.random.shuffle() function. But there are times when you want to implement a custom shuffle() function, and some programming languages don’t have a built-in shuffle() function.

The usual algorithm to shuffle the contents of an array is called the Fisher-Yates shuffle, or sometimes the Knuth shuffle. There are several variations. The variation I prefer, implemented in Python, is:

def shuffle_basic(arr, rndobj):
  # last i iteration not necessary
  n = len(arr)
  for i in range(n):  # 0 to n-1 inclusive
    j = rndobj.randint(i, n)  # i to n-1 inclusive
    tmp = arr[i]; arr[i] = arr[j]; arr[j] = tmp

I iterate forward, with i ranging from 0 to n-1 inclusive. This is not the most common technique. The last iteration always swaps the value in the last cell with itself, so it’s more common to write for i in range(n-1). I don’t mind the extra iteration because (1) if an array has just one cell or is empty, a loop bound of n-1 is an edge case that can misbehave in some programming languages, for example when n is an unsigned integer and the array is empty (not Python), and (2) the for-loop range and the randint() range have the same upper bound, which has a pleasant symmetry.
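
To check the claim that the last iteration is a no-op, here is a small sketch comparing the two loop ranges. It assumes the standard-library random.Random in place of a NumPy RandomState; randrange(i, n) uses the same half-open convention as randint(i, n) in the code above. The helper name shuffle_range and the seed are made up for the illustration.

```python
import random

def shuffle_range(arr, rnd, stop):
  # forward Fisher-Yates; stop is n (full loop) or n-1 (skip last pass)
  n = len(arr)
  for i in range(stop):
    j = rnd.randrange(i, n)  # i to n-1 inclusive
    arr[i], arr[j] = arr[j], arr[i]

a = [0, 1, 2, 3, 4]
b = [0, 1, 2, 3, 4]
shuffle_range(a, random.Random(0), len(a))      # range(n)
shuffle_range(b, random.Random(0), len(b) - 1)  # range(n-1)
print(a == b)  # same seed, same result: the last pass swaps a cell with itself

# in Python, range(0) and range(-1) are both empty, so tiny arrays are safe
one = [7]
shuffle_range(one, random.Random(0), len(one) - 1)
none = []
shuffle_range(none, random.Random(0), len(none) - 1)
```

Both generators start from the same seed and make the same draws for the first n-1 passes, so the two arrays end up identical after a single shuffle.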

Somewhat weirdly, the Wikipedia article on Fisher-Yates gives a basic algorithm that iterates backwards, with the loop variable i counting down. This baffles me. I can think of no reason to awkwardly iterate backwards when a forward iteration works just fine.
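
For comparison, here is a minimal sketch of the backward form. One reason sometimes given for it is that the random index is always drawn from 0 to i, so a generator that only supports "a random integer below k" is enough. That rationale, the function name shuffle_backward, and the use of the standard-library random.Random are assumptions for this sketch.

```python
import random

def shuffle_backward(arr, rnd):
  # backward Fisher-Yates: i runs from n-1 down to 1
  n = len(arr)
  for i in range(n - 1, 0, -1):
    j = rnd.randrange(i + 1)  # 0 to i inclusive
    arr[i], arr[j] = arr[j], arr[i]

arr = [0, 1, 2, 3, 4]
shuffle_backward(arr, random.Random(0))
print(arr)
```

The forward and backward forms do the same amount of work; the difference is only in which end of the array is finalized first.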

Additionally, I prefer to pass in a local random object rather than use the global NumPy random state. This makes reproducibility much easier because no other function can modify the local object.
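
A small sketch of why the local object helps, assuming the standard-library random.Random as a stand-in for a NumPy RandomState. The function name shuffle_local and the seed value are made up for the example.

```python
import random

def shuffle_local(arr, rnd):
  # forward Fisher-Yates driven by the generator passed in
  n = len(arr)
  for i in range(n):
    j = rnd.randrange(i, n)  # i to n-1 inclusive
    arr[i], arr[j] = arr[j], arr[i]

first = [0, 1, 2, 3, 4]
shuffle_local(first, random.Random(1))

random.random()  # some other code consumes the global generator

second = [0, 1, 2, 3, 4]
shuffle_local(second, random.Random(1))
print(first == second)  # the local generator is unaffected by the global call
```

Had the shuffle drawn from the global generator, the intervening call would have shifted the random sequence and the two runs could differ.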

It’s well known that it’s easy to mess up the Fisher-Yates shuffle algorithm. Specifically:

def shuffle_bad(arr, rndobj):
  # biased: j is drawn from the whole array on every iteration
  n = len(arr)
  for i in range(n):  # 0 to n-1 inclusive
    j = rndobj.randint(0, n)  # 0 to n-1 is wrong
    tmp = arr[i]; arr[i] = arr[j]; arr[j] = tmp

The call to the randint() function looks correct, and the algorithm does spit out random-looking arrangements. However, this version is biased towards certain permutations. Specifically, if arr = [0,1,2], then the first, fifth and sixth arrangements [0,1,2], [2,0,1], [2,1,0] are slightly less likely (4/27 each) than the second, third and fourth arrangements [0,2,1], [1,0,2], [1,2,0] (5/27 each).
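
The bias can be checked exactly rather than by sampling. With three cells the faulty version makes three independent choices of j, each from three values, so there are 27 equally likely sequences, and 27 cannot be split evenly among 6 arrangements. The sketch below enumerates all of them, starting from [0, 1, 2] each time. The function name bad_shuffle_counts is made up for the illustration.

```python
from collections import Counter
from fractions import Fraction
from itertools import product

def bad_shuffle_counts(n):
  # count how often each arrangement results over all n**n equally
  # likely sequences of j values chosen by the faulty shuffle
  counts = Counter()
  for js in product(range(n), repeat=n):
    arr = list(range(n))
    for i, j in enumerate(js):
      arr[i], arr[j] = arr[j], arr[i]
    counts[tuple(arr)] += 1
  return counts

counts = bad_shuffle_counts(3)
for perm in sorted(counts):
  print(perm, counts[perm], Fraction(counts[perm], 27))
```

The counts come out as 4, 5, 5, 5, 4, 4 for the six arrangements in sorted order.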

I implemented the bad version of Fisher-Yates and ran it to shuffle [0, 1, 2] 100,000 times. If all six permutations were equally likely, you’d get 1/6 = 0.1667 of each, and the biased version should instead give about 0.148 for three of the arrangements and about 0.185 for the other three. Strangely, in my demo I got [0.1665 0.1673 0.1641 0.1676 0.1682 0.1664], which looks uniform. This is not at all what theory predicts and I’m not sure what’s going on. I know from a lifetime of experience that this problem is now in the back of my mind and will stay there, nagging at me, until I solve it. But I will.
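
One possible explanation, offered as an assumption rather than a settled answer: the demo never resets the array between shuffles, so each faulty shuffle starts from the previous result instead of from [0, 1, 2]. Repeatedly applying a biased rearrangement to an already-scrambled array tends to wash the bias out, which would give frequencies near 0.1667. The sketch below compares the two setups using the standard-library random.Random; the frequencies() helper, the trial count, and the seed are made up for the illustration.

```python
import random

PERMS = ["012", "021", "102", "120", "201", "210"]

def shuffle_bad(arr, rnd):
  # faulty version: j is drawn from the whole array every time
  n = len(arr)
  for i in range(n):
    j = rnd.randrange(0, n)
    arr[i], arr[j] = arr[j], arr[i]

def frequencies(trials, reset, seed=0):
  # reset=True starts every trial from [0, 1, 2];
  # reset=False keeps shuffling the same array, as the demo does
  rnd = random.Random(seed)
  counts = dict.fromkeys(PERMS, 0)
  arr = [0, 1, 2]
  for _ in range(trials):
    if reset:
      arr = [0, 1, 2]
    shuffle_bad(arr, rnd)
    counts["".join(map(str, arr))] += 1
  return [counts[p] / trials for p in PERMS]

with_reset = frequencies(30_000, reset=True)
carry_over = frequencies(30_000, reset=False)
print([round(f, 4) for f in with_reset])  # theory: 4/27 or 5/27 each
print([round(f, 4) for f in carry_over])  # theory: close to 1/6 each
```

If this explanation is right, resetting the array inside the demo loop should make the bias show up in the reported frequencies.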



Left: Synchronized shuffle dancing. Center: The Shuffle Inn arcade bowling game. Right: Shuffling playing cards.


Demo code below.

# fisher_yates.py

import numpy as np

# -------------------------------------------------------------

def shuffle_basic(arr, rndobj):
  # last i iteration not necessary
  n = len(arr)
  for i in range(n):  # 0 to n-1 inclusive
    j = rndobj.randint(i, n)  # i to n-1 inclusive
    tmp = arr[i]; arr[i] = arr[j]; arr[j] = tmp

def shuffle_efficient(arr, rndobj):
  # don't do last i iteration
  n = len(arr)
  for i in range(n-1):  # 0 to n-2 inclusive
    j = rndobj.randint(i, n)  # i to n-1 inclusive
    tmp = arr[i]; arr[i] = arr[j]; arr[j] = tmp

def shuffle_bad(arr):
  # looks good but is biased towards some patterns
  n = len(arr)
  for i in range(n):  # 0 to n-1 inclusive
    j = np.random.randint(0, n)  # 0 to n-1 inclusive
    tmp = arr[i]; arr[i] = arr[j]; arr[j] = tmp

def shuffle_global_rnd(arr):
  # using a global Random makes reproducibility difficult
  n = len(arr)
  for i in range(n):  # 0 to n-1 inclusive
    j = np.random.randint(i, n)  # i to n-1 inclusive
    tmp = arr[i]; arr[i] = arr[j]; arr[j] = tmp  

# -------------------------------------------------------------

def main():
  print("\nBegin Fisher-Yates shuffle demo ")
  np.set_printoptions(formatter={'float': '{: 0.4f}'.format})

  rndobj = np.random.RandomState(0)

  arr = np.array([0, 1, 2, 3, 4], dtype=np.int64)
  print("\narray before shuffle: ")
  print(arr)

  print("\nShuffling five times using basic algorithm ")
  for i in range(5):
    shuffle_basic(arr, rndobj)
    print(arr)

# -------------------------------------------------------------

  rndobj = np.random.RandomState(0)

  arr = np.array([0, 1, 2, 3, 4], dtype=np.int64)
  print("\narray before shuffle: ")
  print(arr)

  print("\nShuffling five times using \"efficient\" technique ")
  for i in range(5):
    shuffle_efficient(arr, rndobj)
    print(arr)

# -------------------------------------------------------------

  rndobj = np.random.RandomState(0)

  arr = np.array([0, 1, 2], dtype=np.int64)
  print("\narray before shuffle: ")
  print(arr)
  
  counts = np.zeros(6, dtype=np.int64)
  print("\nShuffling 100_000 times using basic technique ") 
  for i in range(100_000):
    shuffle_basic(arr, rndobj)
    s = str(arr[0]) + str(arr[1]) + str(arr[2]) 
    if   s == "012": counts[0] += 1
    elif s == "021": counts[1] += 1
    elif s == "102": counts[2] += 1
    elif s == "120": counts[3] += 1
    elif s == "201": counts[4] += 1
    elif s == "210": counts[5] += 1

  freqs = (counts * 1.0) / np.sum(counts)
  print("\nResult frequencies (0.1667 expected): ")
  print(freqs)

# -------------------------------------------------------------

  np.random.seed(0)  # program Global

  arr = np.array([0, 1, 2], dtype=np.int64)
  print("\narray before shuffle: ")
  print(arr)

  counts = np.zeros(6, dtype=np.int64)
  print("\nShuffling 100_000 times using faulty technique ") 
  for i in range(100_000):
    # note: arr is not reset here, so each shuffle starts from the previous result
    shuffle_bad(arr)
    s = str(arr[0]) + str(arr[1]) + str(arr[2]) 
    if   s == "012": counts[0] += 1
    elif s == "021": counts[1] += 1
    elif s == "102": counts[2] += 1
    elif s == "120": counts[3] += 1
    elif s == "201": counts[4] += 1
    elif s == "210": counts[5] += 1

  freqs = (counts * 1.0) / np.sum(counts)
  print("\nResult frequencies (0.1667 expected): ")
  print(freqs)
  
  print("\nEnd demo ")

if __name__ == "__main__":
  main()