Gibbs Sampling Example Using a Discrete Distribution

Gibbs sampling is a difficult topic because it combines about half a dozen ideas from probability, each of which is tricky on its own. No single example can make Gibbs sampling completely clear; you need to look at several. Here’s one that uses a discrete distribution instead of the usual continuous (Gaussian) example.

Suppose you have a coin and a spinner (which lands on 0, 1, or 2), and that they’re connected in some mysterious way so that the outcome of one depends on the outcome of the other.

Now suppose that, unknown to you, the true joint probability distribution is:

P(Heads and 0) = 0.10
P(Heads and 1) = 0.20
P(Heads and 2) = 0.10
P(Tails and 0) = 0.30
P(Tails and 1) = 0.10
P(Tails and 2) = 0.20

But suppose you somehow know the two conditional probability distributions:

P(0 | Heads) = 1/4
P(1 | Heads) = 1/2
P(2 | Heads) = 1/4

P(0 | Tails) = 1/2
P(1 | Tails) = 1/6
P(2 | Tails) = 1/3 

---------------------

P(Heads | 0) = 1/4
P(Tails | 0) = 3/4

P(Heads | 1) = 2/3
P(Tails | 1) = 1/3

P(Heads | 2) = 1/3
P(Tails | 2) = 2/3
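
Both conditional tables follow from the joint table. You divide each joint probability by a marginal. For example, P(Heads) = 0.10 + 0.20 + 0.10 = 0.40, so P(1 | Heads) = 0.20 / 0.40 = 1/2. Likewise, P(1) = 0.20 + 0.10 = 0.30, so P(Heads | 1) = 0.20 / 0.30 = 2/3. The short sketch below repeats that arithmetic with exact fractions as a consistency check. Its dictionary names are illustrative only and are not part of the demo. In the scenario above you would only know the conditionals, not the joint.

```python
from fractions import Fraction as F

# True joint distribution from the post: (coin, spinner) -> probability
joint = {
    ("Heads", 0): F(10, 100), ("Heads", 1): F(20, 100), ("Heads", 2): F(10, 100),
    ("Tails", 0): F(30, 100), ("Tails", 1): F(10, 100), ("Tails", 2): F(20, 100),
}
coins = ("Heads", "Tails")
spins = (0, 1, 2)

# Marginals: sum the joint over the other variable
p_coin = {c: sum(joint[(c, s)] for s in spins) for c in coins}
p_spin = {s: sum(joint[(c, s)] for c in coins) for s in spins}

# Conditionals: P(s | c) = P(c and s) / P(c) and P(c | s) = P(c and s) / P(s)
p_spin_given_coin = {(s, c): joint[(c, s)] / p_coin[c] for c in coins for s in spins}
p_coin_given_spin = {(c, s): joint[(c, s)] / p_spin[s] for c in coins for s in spins}

for c in coins:
    for s in spins:
        print(f"P({s} | {c}) =", p_spin_given_coin[(s, c)])
for s in spins:
    for c in coins:
        print(f"P({c} | {s}) =", p_coin_given_spin[(c, s)])
```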

Gibbs sampling is an algorithm that generates samples from the joint distribution using only the conditional distributions. The fraction of samples that land on each (coin, spinner) pair estimates that pair’s joint probability. The key part of the algorithm, in Python code, looks like:

c, s = 0, 0  # arbitrary starting values
for i in range(n):
  for j in range(thin):
    c = pick_coin(s)  # 0 = Heads, 1 = Tails
    s = pick_spinner(c)  # 0, 1, 2
  # after thin steps, increment the counter for the current (c, s) pair

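The purpose of “thin” is to count only pairs that weren’t generated close together. Each new value is drawn conditioned on the previous value of the other variable, so consecutive (coin, spinner) pairs are correlated. Counting only the last pair of every run of thin steps reduces that correlation. Function pick_coin() returns a random Heads (0) or Tails (1), using P(coin | spinner) for the current spinner value. Function pick_spinner() returns a random 0, 1, or 2, using P(spinner | coin) for the current coin value.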
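
The same loop can also be written so that it reads the two conditional tables directly instead of hard-coding thresholds. The sketch below runs one long chain with an explicit burn-in, which discards the early pairs that still depend on the arbitrary starting values, plus thinning. The names gibbs_table, sample_index, burn_in, and seed are hypothetical and were introduced for this sketch. The parameter values in the usage line are illustrative assumptions. The demo program has no separate burn-in, but because it counts only the last pair of each run of 500 steps, its first 499 pairs are discarded anyway.

```python
import numpy as np

# Conditional tables copied from the post. Coin: 0 = Heads, 1 = Tails.
# Row s of P_COIN_GIVEN_SPIN is P(coin | spinner = s).
P_COIN_GIVEN_SPIN = [[1/4, 3/4],
                     [2/3, 1/3],
                     [1/3, 2/3]]
# Row c of P_SPIN_GIVEN_COIN is P(spinner | coin = c).
P_SPIN_GIVEN_COIN = [[1/4, 1/2, 1/4],
                     [1/2, 1/6, 1/3]]


def sample_index(probs, u):
    """Inverse-CDF draw: first index whose cumulative probability exceeds u."""
    total = 0.0
    for k, p in enumerate(probs):
        total += p
        if u < total:
            return k
    return len(probs) - 1  # guard against round-off leaving total slightly below 1


def gibbs_table(n=1000, thin=10, burn_in=100, seed=0):
    """Hypothetical helper: estimate the 2x3 joint table from one Gibbs chain."""
    rng = np.random.default_rng(seed)
    counts = np.zeros((2, 3), dtype=np.int64)
    c, s = 0, 0  # arbitrary starting state
    for step in range(1, burn_in + n * thin + 1):
        c = sample_index(P_COIN_GIVEN_SPIN[s], rng.random())
        s = sample_index(P_SPIN_GIVEN_COIN[c], rng.random())
        # discard the burn-in, then keep every thin-th (c, s) pair
        if step > burn_in and (step - burn_in) % thin == 0:
            counts[c, s] += 1
    return counts / n


est = gibbs_table(n=20000, thin=5, burn_in=100, seed=0)
print(np.round(est, 4))  # rows: Heads, Tails; columns: spinner 0, 1, 2
```

Driving the sampler from probability tables means the same function would work for other discrete conditionals of the same shape. The if/elif version in the demo is easier to read line by line.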

My Gibbs demo collects 1,000 samples, each taken after 500 inner steps (500,000 coin-and-spinner updates in all), and gives a close approximation to the true joint distribution.
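
Because this example is tiny, you can also check the result exactly instead of by simulation. One coin-then-spinner update is a Markov transition between the six (coin, spinner) pairs, so it can be written as a 6x6 matrix. The sketch below builds that matrix from the conditional tables and applies it 50 times starting from (Heads, 0). It then checks that one update leaves the true joint distribution unchanged, which makes the joint the chain’s stationary distribution. The function name sweep_matrix and the state ordering 3*coin + spinner are assumptions made for this illustration.

```python
import numpy as np

# Conditional tables from the post. Coin: 0 = Heads, 1 = Tails.
P_COIN_GIVEN_SPIN = np.array([[1/4, 3/4],    # spinner 0
                              [2/3, 1/3],    # spinner 1
                              [1/3, 2/3]])   # spinner 2
P_SPIN_GIVEN_COIN = np.array([[1/4, 1/2, 1/4],   # Heads
                              [1/2, 1/6, 1/3]])  # Tails

# True joint in the demo's order 00, 01, 02, 10, 11, 12 (index = 3*coin + spinner)
JOINT = np.array([0.10, 0.20, 0.10, 0.30, 0.10, 0.20])


def sweep_matrix():
    """T[i, j] = probability that one coin-then-spinner update moves pair i to pair j."""
    T = np.zeros((6, 6))
    for c in range(2):
        for s in range(3):
            for c2 in range(2):
                for s2 in range(3):
                    # the new coin depends only on the old spinner,
                    # the new spinner only on the new coin
                    T[3 * c + s, 3 * c2 + s2] = (P_COIN_GIVEN_SPIN[s, c2]
                                                 * P_SPIN_GIVEN_COIN[c2, s2])
    return T


T = sweep_matrix()

start = np.zeros(6)
start[0] = 1.0                                 # begin at (Heads, 0), like the demo
after = start @ np.linalg.matrix_power(T, 50)  # distribution after 50 updates
print(np.round(after, 4))
print(np.allclose(JOINT @ T, JOINT))           # is the true joint stationary?

moduli = np.sort(np.abs(np.linalg.eigvals(T)))[::-1]
print(moduli[1])                               # second-largest eigenvalue modulus
```

The second-largest eigenvalue modulus measures how quickly the chain forgets its previous state. For these tables it works out to 19/144, about 0.13. That suggests the correlation between pairs dies out within a few updates for this toy problem, so a thin value much smaller than 500 would likely be enough here. That is a property of this particular example, not a general rule.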

If you’ve stumbled on this blog post while searching the Internet for an example of Gibbs sampling, I hope you pick up one more piece of the very complex puzzle.


# gibbs_discrete.py

import numpy as np
np.random.seed(0)

def pick_coin(spinner):
  # draw the coin from P(coin | spinner): 0 = Heads, 1 = Tails
  p = np.random.random()  # [0.0, 1.0)
  if spinner == 0:
    if p < 0.250000: return 0  # P(Heads | 0) = 1/4
    else: return 1
  elif spinner == 1:
    if p < 0.666667: return 0  # P(Heads | 1) = 2/3
    else: return 1
  elif spinner == 2:
    if p < 0.333333: return 0  # P(Heads | 2) = 1/3
    else: return 1

def pick_spinner(coin):
  # draw the spinner from P(spinner | coin): 0, 1, or 2
  p = np.random.random()  # [0.0, 1.0)
  if coin == 0:  # Heads: 1/4, 1/2, 1/4
    if p < 0.250000: return 0
    elif p < 0.750000: return 1
    else: return 2
  elif coin == 1:  # Tails: 1/2, 1/6, 1/3
    if p < 0.500000: return 0
    elif p < 0.666667: return 1
    else: return 2

def gibbs_sample(n=1000, thin=500):
  counts = np.zeros((2, 3), dtype=np.int64)  # counts[coin, spinner]
  c, s = 0, 0  # arbitrary starting state
  for i in range(n):
    for j in range(thin):
      c = pick_coin(s)  # 0 or 1
      s = pick_spinner(c)  # 0, 1, 2
    counts[c, s] += 1  # keep only the last pair of each thin run

  for ci in range(2):
    for si in range(3):
      print("P(%d%d) = %0.4f" % (ci, si, counts[ci, si] / n))


print("\nBegin \n")
print("True joint distribution: \n")
print("P(00) = 0.10")
print("P(01) = 0.20")
print("P(02) = 0.10")
print("P(10) = 0.30")
print("P(11) = 0.10")
print("P(12) = 0.20")

print("\nEstimated using Gibbs sampling: \n") 
gibbs_sample(1000, 500)

print("\nDone \n")

This entry was posted in Miscellaneous.