An Interesting Yin-Yang Dataset for Machine Learning

I was reviewing recent research papers on neuromorphic computing (a forward-looking concept where machine learning systems more closely resemble biological systems such as spiking input chains) and I came across an interesting dataset used in the paper “Fast and Energy-Efficient Neuromorphic Deep Learning with First-Spike Times”. The authors called their dataset the Yin-Yang dataset.


Red is class 0, Blue is class 1, Green is class 2

The data is generated programmatically. Each data item has four predictor values that look like [0.28, 0.66, 0.72, 0.34] and each item is one of three classes 0, 1, 2. When graphed, the data looks like a yin-yang symbol. Somewhat oddly, the first and third predictor values are complements with respect to 1 (0.28 and 0.72 in the example above, which sum to 1) and the second and fourth values are also complements. Therefore, only the first two predictor values carry independent information.
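The complement relationship is easy to check numerically. Here is a minimal sketch, assuming the items are stacked into a NumPy array with one row per item. The helper names complements_hold and to_xy are hypothetical and are not part of the GitHub project; the two sample rows are the example values that appear in this post.

```python
import numpy as np

# Two items in the four-value layout described above
# (values taken from the examples in this post).
items = np.array([
    [0.28, 0.66, 0.72, 0.34],
    [0.68, 0.45, 0.32, 0.55],
])

def complements_hold(items, tol=1e-6):
    # Hypothetical helper: True if column 0 + column 2 == 1
    # and column 1 + column 3 == 1 for every row (within tol).
    x_ok = np.allclose(items[:, 0] + items[:, 2], 1.0, atol=tol)
    y_ok = np.allclose(items[:, 1] + items[:, 3], 1.0, atol=tol)
    return bool(x_ok and y_ok)

def to_xy(items):
    # Hypothetical helper: keep only the first two predictors,
    # which are the (x, y) coordinates used for plotting.
    return items[:, :2]

print(complements_hold(items))
print(to_xy(items).shape)
```

If the check passes, dropping the last two columns loses no information, because they can always be recomputed as 1 minus the first two.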

I went to the GitHub project at github.com/lkriener/yin_yang_data_set and copy-pasted the code for the YinYangDataset class. After a bit of work, I got a demo program running where I generated 1000 data items and plotted them, using just the first predictor value for the x-axis and the second value for the y-axis.

Interesting and good fun.



The yin-yang pattern is somewhat of a design cliche but here are three examples that I think are pretty nice.


Demo code.

# yin_yang_demo.py


import numpy as np
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# ----------------------------------------------------------------
        
class YinYangDataset(Dataset):
  # class body omitted here; copy it directly from
  # github.com/lkriener/yin_yang_data_set

# ----------------------------------------------------------------

def main():
  print("\nBegin YinYangDataset demo ")
  print("\nGenerating 1000 data items ")

  ds = YinYangDataset(size=1000, seed=1)
  loader = DataLoader(ds, batch_size=1, shuffle=False)

  xs = []  # list to pass to plot
  ys = []
  cs = []

  for bix, batch in enumerate(loader):
    # bix is batch index, 0 to 999
    X = batch[0]  # like [[0.6800, 0.4500, 0.3200, 0.5500]]
    c = batch[1]  # 0, 1, 2
    xs.append(X[0][0].item())  # .item() gives a plain Python float
    ys.append(X[0][1].item())
    cs.append(c.item())

  xs = np.array(xs)  # convert for plot
  ys = np.array(ys)
  cs = np.array(cs)
    
  plt.scatter(xs[cs==0], ys[cs==0], color='red',
    edgecolor='black')
  plt.scatter(xs[cs==1], ys[cs==1], color='blue',
    edgecolor='black')
  plt.scatter(xs[cs==2], ys[cs==2], color='green',
    edgecolor='black')
  plt.show()

  print("\nEnd ")

if __name__ == "__main__":
  main()
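
The DataLoader with batch_size=1 is not strictly needed just to collect points for a plot. A possible alternative, sketched below, indexes the dataset directly. This assumes that dataset[i] returns a (values, label) pair and that len(dataset) works, which is what the DataLoader in the demo relies on. The function name collect_points is hypothetical, and the small list named toy is stand-in data with made-up class labels, used only so the sketch runs without the YinYangDataset class.

```python
import numpy as np

def collect_points(dataset):
    # Hypothetical helper: gather x, y and class label arrays
    # from any indexable dataset of (values, label) pairs.
    n = len(dataset)
    xs = np.empty(n, dtype=np.float64)
    ys = np.empty(n, dtype=np.float64)
    cs = np.empty(n, dtype=np.int64)
    for i in range(n):
        values, label = dataset[i]
        xs[i] = float(values[0])  # first predictor -> x-axis
        ys[i] = float(values[1])  # second predictor -> y-axis
        cs[i] = int(label)
    return xs, ys, cs

# Stand-in data; the class labels here are invented for illustration.
toy = [
    ([0.28, 0.66, 0.72, 0.34], 0),
    ([0.68, 0.45, 0.32, 0.55], 2),
]
xs, ys, cs = collect_points(toy)
print(xs[cs == 0], ys[cs == 0])
```

The returned arrays can be passed to plt.scatter() with the same cs==0, cs==1, cs==2 masks used in the demo.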