Another Look at Nearest Centroid Classification for Data with Numeric Predictors, Using C#

Several weeks ago, I put together a demo of nearest centroid classification for data that has strictly numeric predictor variables, using the C# language. ( https://jamesmccaffreyblog.com/2024/04/08/nearest-centroid-classification-from-scratch-using-csharp/ )

One morning before work, I decided to tidy up that example. I made many changes. The three most notable are the change from divide-by-constant normalization to min-max normalization, the addition of a PredictProbs() method, and the addition of a confusion matrix for model evaluation.

Briefly, in nearest centroid classification, the centroid (also called the mean or average vector) of the training items is computed for each class to predict. To classify a data item, the distance between the item and each class centroid is computed. The predicted class is the class associated with the nearest centroid.
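In symbols (my notation, not the original post's): suppose there are K classes, n_c training items labeled c, d predictors, and training items x_i with labels y_i. Then the centroid of class c and the predicted class for a new item x are:

```latex
\mu_c = \frac{1}{n_c} \sum_{i \,:\, y_i = c} \mathbf{x}_i ,
\qquad
\hat{y}(\mathbf{x}) = \arg\min_{c \in \{0, \ldots, K-1\}}
  \sqrt{ \sum_{j=1}^{d} \left( x_j - \mu_{c,j} \right)^2 }
```

The square root is increasing, so dropping it gives the same predicted class and saves a little computation.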

For the demo data, I used a small subset of the Penguin Dataset. The 30-item raw training data looks like:

2, 50.0, 16.3, 230.0, 5700.0
0, 39.1, 18.7, 181.0, 3750.0
1, 38.8, 17.2, 180.0, 3800.0
2, 39.3, 20.6, 190.0, 3650.0
. . .

The goal is to predict a species of penguin (0 = Adelie, 1 = Chinstrap, 2 = Gentoo) from bill length, bill width, flipper length, and body mass. After programmatic min-max normalization, the normalized data looks like:

2    0.851    0.257    1.000    0.941
0    0.249    0.600    0.020    0.176
1    0.232    0.386    0.000    0.196
2    0.260    0.871    0.200    0.137
. . .

where the predictor min and max values are:

mins = 34.6  14.5  180.0  3300.0
maxs = 52.7  21.5  230.0  5850.0
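Min-max normalization maps each predictor value x to (x - min) / (max - min), where min and max are that column's values from the training data. For example, the first training item's bill length of 50.0 becomes (50.0 - 34.6) / (52.7 - 34.6) = 15.4 / 18.1 = 0.851. Here is a minimal C# sketch of the per-item computation. It assumes the same two-row mins/maxs layout that the demo's MatMinMaxValues() returns. The class name MinMaxSketch is hypothetical, and unlike the demo, the sketch maps a constant column (max equal to min) to 0.0 rather than dividing by zero.

```csharp
using System;

// Hypothetical helper (not part of the demo) that applies the same
// min-max formula as the demo's VecNormalizeUsing():
//   normalized = (x - colMin) / (colMax - colMin)
public static class MinMaxSketch
{
  // minsMaxs[0] holds the column mins and minsMaxs[1] the column maxs,
  // matching the layout returned by the demo's MatMinMaxValues()
  public static double[] Normalize(double[] x, double[][] minsMaxs)
  {
    double[] result = new double[x.Length];
    for (int j = 0; j < x.Length; ++j)
    {
      double range = minsMaxs[1][j] - minsMaxs[0][j];
      // assumption: a constant column maps to 0.0 instead of dividing by zero
      result[j] = (range == 0.0) ? 0.0 : (x[j] - minsMaxs[0][j]) / range;
    }
    return result;
  }
}
```

The test data is normalized with the training mins and maxs, so a test value outside the training range normalizes to a value outside [0, 1]. For instance, Normalize() applied to the test item with predictors 49.2, 15.2, 221.0, 6300.0 returns approximately 0.8066, 0.1000, 0.8200, 1.1765, because the body mass of 6300.0 exceeds the training max of 5850.0. The distance computations still work with such values.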

The training data is loaded into memory and normalized like so:

string trainFile = 
  "..\\..\\..\\Data\\penguin_train_30.txt";
double[][] trainX = MatLoad(trainFile,
  new int[] { 1, 2, 3, 4 }, ',', "#");
double[][] minsMaxs = MatMinMaxValues(trainX); 
trainX = MatNormalizeUsing(trainX, minsMaxs);
int[] trainY = VecLoad(trainFile, 0, "#");

The test data is loaded and normalized similarly:

string testFile = 
  "..\\..\\..\\Data\\penguin_test_10.txt";
double[][] testX = MatLoad(testFile,
  new int[] { 1, 2, 3, 4 }, ',', "#");
testX = MatNormalizeUsing(testX, minsMaxs);
int[] testY = VecLoad(testFile, 0, "#");
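The demo's MatLoad() reads the file twice (once to count rows, once to parse), and it calls double.Parse() with the machine's current culture. Here is a hedged alternative sketch. It assumes the same comma-separated format with "#" comment lines, and the class and method names are hypothetical. It makes one pass with LINQ, skips blank lines as well as comment lines, and parses with CultureInfo.InvariantCulture so that a value like 50.0 is read the same way on a machine whose locale uses a comma as the decimal separator.

```csharp
using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;

// Hypothetical alternative to the demo's MatLoad(); not the original code.
public static class LoadSketch
{
  // Parse the selected columns of delimited text lines, skipping
  // blank lines and lines that start with the comment marker.
  public static double[][] ParseRows(IEnumerable<string> lines,
    int[] usecols, char sep, string comment)
  {
    return lines
      .Where(line => line.Trim().Length > 0 &&
        !line.StartsWith(comment, StringComparison.Ordinal))
      .Select(line => line.Split(sep))
      .Select(tokens => usecols
        .Select(k => double.Parse(tokens[k].Trim(),
          CultureInfo.InvariantCulture))
        .ToArray())
      .ToArray();
  }

  // Read a file line by line and parse it in a single pass.
  public static double[][] LoadMatrix(string fn, int[] usecols,
    char sep, string comment)
  {
    return ParseRows(File.ReadLines(fn), usecols, sep, comment);
  }
}
```

For a well-formed file, LoadMatrix(trainFile, new int[] { 1, 2, 3, 4 }, ',', "#") should give the same raw matrix as the demo's MatLoad() call. The class labels could be obtained by loading column 0 and casting each value to int, as VecLoad() does.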

The nearest centroid classifier is instantiated and trained using these statements:

int numClasses = 3;
NearestCentroidClassifier ncc = 
  new NearestCentroidClassifier(numClasses);
ncc.Train(trainX, trainY);
Console.WriteLine("Class centroids: ");
MatShow(ncc.centroids, 4, 9, 3, true);
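Training computes, for each class, the average of each predictor column over the training items with that label. The demo's Train() does this with explicit sums and counts. A compact LINQ sketch of the same computation follows. CentroidSketch is a hypothetical name, and one design difference is that Enumerable.Average() throws an InvalidOperationException if a class has no training items, whereas the demo's division would silently produce NaN values.

```csharp
using System;
using System.Linq;

// Hypothetical LINQ version of what the demo's Train() computes.
public static class CentroidSketch
{
  // centroids[c][j] = mean of column j over the items whose label is c
  public static double[][] ComputeCentroids(double[][] X, int[] y,
    int numClasses)
  {
    int dim = X[0].Length;
    return Enumerable.Range(0, numClasses)
      .Select(c =>
      {
        // training rows that belong to class c
        double[][] members = X.Where((row, i) => y[i] == c).ToArray();
        // average each column over those rows
        return Enumerable.Range(0, dim)
          .Select(j => members.Average(row => row[j]))
          .ToArray();
      })
      .ToArray();
  }
}
```

For toy data with items (0, 0) and (2, 0) labeled 0 and items (10, 10) and (10, 12) labeled 1, ComputeCentroids() returns the centroids (1, 0) and (10, 11).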

The trained model is evaluated like this:

double accTrain = ncc.Accuracy(trainX, trainY);
Console.WriteLine("Accuracy on train: " +
  accTrain.ToString("F4"));
double accTest = ncc.Accuracy(testX, testY);
Console.WriteLine("Accuracy on test: " +
  accTest.ToString("F4"));
int[][] cm = ncc.ConfusionMatrix(trainX, trainY);
ncc.ShowConfusion(cm);

The output is:

Accuracy on train: 0.9333
Accuracy on test: 1.0000

Confusion matrix for training data:
actual 0:    8    0    0  |    8 |  1.0000
actual 1:    1   12    0  |   13 |  0.9231
actual 2:    1    0    8  |    9 |  0.8889

The trained model scores 0.9333 accuracy (28 out of 30 correct) on the training data and 1.0000 accuracy (10 out of 10 correct) on the test data. Nearest centroid classification often doesn’t work well, but it works very well for this data. In fact, I had to manipulate the training data by changing the class labels of two training items just to get any incorrect predictions at all.
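Both the overall accuracy and the per-class values in the right-most column of the output can be read off the confusion matrix, where row i holds the counts for actual class i and column j the counts for predicted class j. Overall accuracy is the sum of the diagonal divided by the total: (8 + 12 + 8) / 30 = 0.9333. Per-class accuracy divides each diagonal count by its row sum: 8/8 = 1.0000, 12/13 = 0.9231, and 8/9 = 0.8889. Here is a small sketch of that arithmetic, with hypothetical helper names.

```csharp
using System;

// Hypothetical helpers (not part of the demo). They assume the same
// layout as the demo's ConfusionMatrix(): cm[actual][predicted] = count.
public static class ConfusionSketch
{
  // overall accuracy = (sum of diagonal counts) / (sum of all counts)
  public static double OverallAccuracy(int[][] cm)
  {
    int correct = 0, total = 0;
    for (int i = 0; i < cm.Length; ++i)
      for (int j = 0; j < cm[i].Length; ++j)
      {
        total += cm[i][j];
        if (i == j) correct += cm[i][j];
      }
    return (double)correct / total;
  }

  // per-class accuracy = cm[c][c] / (row sum for actual class c),
  // the right-most column of the demo's ShowConfusion() output
  public static double[] PerClassAccuracy(int[][] cm)
  {
    double[] accs = new double[cm.Length];
    for (int c = 0; c < cm.Length; ++c)
    {
      int rowSum = 0;
      foreach (int count in cm[c]) rowSum += count;
      accs[c] = (double)cm[c][c] / rowSum;
    }
    return accs;
  }
}
```

Passing the training matrix { {8, 0, 0}, {1, 12, 0}, {1, 0, 8} } reproduces the 0.9333 overall accuracy and the 1.0000, 0.9231, 0.8889 per-class values.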

My demo concludes by predicting the species for a new, previously unseen penguin:

Console.WriteLine("Predicting species for " +
  "x = 46.5, 17.9, 192, 3500");
string[] speciesNames = 
  new string[] { "Adelie", "Chinstrap", "Gentoo" };
double[] xRaw = { 46.5, 17.9, 192, 3500 };
double[] xNorm = VecNormalizeUsing(xRaw, minsMaxs);
Console.Write("Normalized x =");
VecShow(xNorm, 4, 9);
int lbl = ncc.Predict(xNorm);
Console.WriteLine("predicted label/class = " + lbl);
Console.WriteLine("predicted species = " +
  speciesNames[lbl]);
double[] pseudoProbs = ncc.PredictProbs(xNorm);
Console.WriteLine("prediction pseudo-probs = ");
VecShow(pseudoProbs, 4, 9);

The output is:

Predicting species for x = 46.5, 17.9, 192, 3500
Normalized x =   0.6575   0.4857   0.2400   0.0784
predicted label/class = 1
predicted species = Chinstrap
prediction pseudo-probs =
   0.2791   0.5522   0.1686
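The pseudo-probabilities come from inverse distances: each class gets 1 / (distance to its centroid), and those values are scaled to sum to 1.0. The largest value always goes to the nearest centroid, but the values aren't calibrated probabilities. Because each value is proportional to the inverse distance, the ratio of two pseudo-probabilities is the inverse ratio of the distances. In the output above, 0.5522 / 0.2791 is about 1.98, so the new item is roughly twice as far from the Adelie centroid as from the Chinstrap centroid. Here is a standalone sketch of the scheme that takes the distances directly. PseudoProbSketch is a hypothetical name, and the 1.0e-8 floor mirrors the demo's guard against division by zero.

```csharp
using System;

// Hypothetical standalone version of the inverse-distance scheme used by
// the demo's PredictProbs(): p[c] = (1 / d[c]) / (sum over k of 1 / d[k])
public static class PseudoProbSketch
{
  public static double[] FromDistances(double[] distances)
  {
    int n = distances.Length;
    double[] probs = new double[n];
    double sum = 0.0;  // sum of inverse distances
    for (int c = 0; c < n; ++c)
    {
      // same small floor as the demo, to avoid dividing by zero
      double d = Math.Max(distances[c], 1.0e-8);
      probs[c] = 1.0 / d;
      sum += probs[c];
    }
    for (int c = 0; c < n; ++c)
      probs[c] /= sum;  // scale so the values sum to 1.0
    return probs;
  }
}
```

For distances 1.0, 2.0, and 4.0, FromDistances() returns approximately 0.5714, 0.2857, and 0.1429: the inverse distances 1.0, 0.5, and 0.25 divided by their sum of 1.75.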

Nearest centroid classification is arguably the simplest possible classification technique. Compared to other techniques, four advantages of nearest centroid classification (NCC) are that NCC is easy to implement, NCC can work with very small datasets, NCC is highly interpretable, and NCC works for both binary and multi-class classification. Two disadvantages of NCC are that NCC works directly only with strictly numeric predictor variables, and, most importantly, NCC is usually less powerful than other classification techniques.

Nearest centroid classification isn’t as powerful as other classification techniques because it can’t model interactions between predictor variables: each class is summarized by a single centroid, so the boundary between any two classes is a straight line (a hyperplane in higher dimensions). But nearest centroid classification is a good way to establish a baseline prediction model result. And there are some situations, such as predicting species in the Penguin Dataset, where nearest centroid classification is surprisingly powerful.
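To make the interaction point concrete, here is a made-up XOR-style example (not from the Penguin data) in which the class is 0 when the two predictors are equal and 1 when they differ. Both class centroids land at (0.5, 0.5), so every item is the same distance from both centroids, and nearest centroid classification can do no better than chance. The sketch is a minimal standalone nearest centroid implementation with hypothetical names, not the demo's class. As in the demo's Predict(), a tie goes to the lower class index.

```csharp
using System;

// Hypothetical minimal nearest centroid code (not the demo's class),
// used only to show the effect of an interaction between predictors.
public static class XorSketch
{
  // per-class mean of each column
  public static double[][] Centroids(double[][] X, int[] y, int numClasses)
  {
    int dim = X[0].Length;
    double[][] sums = new double[numClasses][];
    int[] counts = new int[numClasses];
    for (int c = 0; c < numClasses; ++c) sums[c] = new double[dim];
    for (int i = 0; i < X.Length; ++i)
    {
      ++counts[y[i]];
      for (int j = 0; j < dim; ++j) sums[y[i]][j] += X[i][j];
    }
    for (int c = 0; c < numClasses; ++c)
      for (int j = 0; j < dim; ++j) sums[c][j] /= counts[c];
    return sums;
  }

  // nearest centroid by Euclidean distance; strict "less than" means
  // a tie goes to the lower class index, as in the demo's Predict()
  public static int Predict(double[][] centroids, double[] x)
  {
    int best = 0;
    double bestDist = double.MaxValue;
    for (int c = 0; c < centroids.Length; ++c)
    {
      double sum = 0.0;
      for (int j = 0; j < x.Length; ++j)
        sum += (x[j] - centroids[c][j]) * (x[j] - centroids[c][j]);
      double dist = Math.Sqrt(sum);
      if (dist < bestDist) { bestDist = dist; best = c; }
    }
    return best;
  }

  public static double Accuracy(double[][] centroids, double[][] X, int[] y)
  {
    int correct = 0;
    for (int i = 0; i < X.Length; ++i)
      if (Predict(centroids, X[i]) == y[i]) ++correct;
    return (double)correct / X.Length;
  }
}
```

With items (0, 0), (1, 1), (0, 1), (1, 0) and labels 0, 0, 1, 1, both centroids are (0.5, 0.5), every item is predicted as class 0, and Accuracy() returns 0.5. Moving the class 1 items far away from the class 0 items, so that no interaction is needed, lets the same code reach 1.0.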



I don’t know much about art, but here are three nice illustrations that I’d say could be classified as having similar styles. Left: By artist Mark Swanson. Center: By artist Kirsten Ulve. Right: By artist Josh Agle.


Demo code. Replace “lt” (less than), “gt” (greater than), “lte”, and “gte” with the corresponding comparison operator symbols (my blog editor chokes on symbols).

using System;
using System.IO;

// simple centroids-based classification

namespace NearestCentroidClassification
{
  internal class NearestCentroidProgram
  {
    static void Main(string[] args)
    {
      Console.WriteLine("\nBegin nearest " +
        "centroid classification demo ");

      // 1. load and normalize training data
      Console.WriteLine("\nLoading penguin subset " +
        "train (30) and test (10) data ");
      string trainFile =
        "..\\..\\..\\Data\\penguin_train_30.txt";
      double[][] trainX = MatLoad(trainFile,
        new int[] { 1, 2, 3, 4 }, ',', "#");
      Console.WriteLine("\nX training raw: ");
      MatShow(trainX, 1, 9, 4, true);

      // get normalized X and mins-maxs
      Console.WriteLine("\nNormalizing train X" +
        " using min-max ");
      double[][] minsMaxs = MatMinMaxValues(trainX);
      trainX = MatNormalizeUsing(trainX, minsMaxs);
      Console.WriteLine("Done ");
      Console.WriteLine("\nX training normalized: ");
      MatShow(trainX, 4, 9, 4, true);

      // get the training data labels/classes/species
      int[] trainY = VecLoad(trainFile, 0, "#");
      Console.WriteLine("\nY training: ");
      VecShow(trainY, wid: 3);

      // 2. load and normalize test data
      Console.WriteLine("\nLoading and " +
        "normalizing test data ");
      string testFile =
        "..\\..\\..\\Data\\penguin_test_10.txt";
      double[][] testX = MatLoad(testFile,
        new int[] { 1, 2, 3, 4 }, ',', "#");
      testX = MatNormalizeUsing(testX, minsMaxs);
      int[] testY = VecLoad(testFile, 0, "#");
      Console.WriteLine("Done ");

      // 3. create and train classifier
      Console.WriteLine("\nCreating " +
        "NearestCentroidClassifier object ");
      int numClasses = 3;
      NearestCentroidClassifier ncc =
        new NearestCentroidClassifier(numClasses);
      Console.WriteLine("Training the classifier ");
      ncc.Train(trainX, trainY);
      Console.WriteLine("Done ");

      Console.WriteLine("\nClass centroids: ");
      MatShow(ncc.centroids, 4, 9, 3, true);

      // 4. evaluate model
      Console.WriteLine("\nEvaluating model ");
      double accTrain = ncc.Accuracy(trainX, trainY);
      Console.WriteLine("Accuracy on train: " +
        accTrain.ToString("F4"));

      double accTest = ncc.Accuracy(testX, testY);
      Console.WriteLine("Accuracy on test: " +
        accTest.ToString("F4"));

      Console.WriteLine("\nConfusion matrix" +
        " for training data: ");
      int[][] cm = ncc.ConfusionMatrix(trainX, trainY);
      ncc.ShowConfusion(cm);

      // 5. use model
      Console.WriteLine("\nPredicting species" +
        " for x = 46.5, 17.9, 192, 3500");

      string[] speciesNames = new string[] { "Adelie",
        "Chinstrap", "Gentoo" };
      double[] xRaw = { 46.5, 17.9, 192, 3500 };
      double[] xNorm = VecNormalizeUsing(xRaw, minsMaxs);
      Console.Write("Normalized x =");
      VecShow(xNorm, 4, 9);

      int lbl = ncc.Predict(xNorm);
      Console.WriteLine("predicted label/class = " + lbl);
      Console.WriteLine("predicted species = " + 
        speciesNames[lbl]);

      double[] pseudoProbs = ncc.PredictProbs(xNorm);
      Console.WriteLine("\nprediction pseudo-probs = ");
      VecShow(pseudoProbs, 4, 9);

      // 6. TODO: consider saving model (centroids)

      Console.WriteLine("\nEnd demo ");
      Console.ReadLine();
    } // Main

    // ------------------------------------------------------

    public static double[][] MatLoad(string fn,
      int[] usecols, char sep, string comment)
    {
      // count number of non-comment lines
      int nRows = 0;
      string line = "";
      FileStream ifs = new FileStream(fn, FileMode.Open);
      StreamReader sr = new StreamReader(ifs);
      while ((line = sr.ReadLine()) != null)
        if (line.StartsWith(comment) == false)
          ++nRows;
      sr.Close(); ifs.Close(); // could reset fp

      // make result matrix
      int nCols = usecols.Length;
      double[][] result = new double[nRows][];
      for (int r = 0; r "lt" nRows; ++r)
        result[r] = new double[nCols];

      line = "";
      string[] tokens = null;
      ifs = new FileStream(fn, FileMode.Open);
      sr = new StreamReader(ifs);

      int i = 0;
      while ((line = sr.ReadLine()) != null)
      {
        if (line.StartsWith(comment) == true)
          continue;
        tokens = line.Split(sep);
        for (int j = 0; j "lt" nCols; ++j)
        {
          int k = usecols[j];  // into tokens
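          // parse with the invariant culture so "50.0" is read
          // correctly regardless of the machine's locale
          result[i][j] = double.Parse(tokens[k],
            System.Globalization.CultureInfo.InvariantCulture);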
        }
        ++i;
      }
      sr.Close(); ifs.Close();
      return result;
    }

    // ------------------------------------------------------

    public static int[] VecLoad(string fn, int usecol,
      string comment)
    {
      char dummySep = ',';
      double[][] tmp = MatLoad(fn, new int[] { usecol },
        dummySep, comment);
      int n = tmp.Length;
      int[] result = new int[n];
      for (int i = 0; i "lt" n; ++i)
        result[i] = (int)tmp[i][0];
      return result;
    }

    // ------------------------------------------------------

    public static double[][] MatMinMaxValues(double[][] X)
    {
      // return min and max values for each column of X
      // mins on row[0] of result, maxs at row[1]

      int nRows = X.Length;
      int nCols = X[0].Length;

      double[][] result = new double[2][];
      for (int i = 0; i "lt" 2; ++i)
        result[i] = new double[nCols];

      for (int j = 0; j "lt" nCols; ++j)
      {
        double colMin = X[0][j];
        double colMax = X[0][j];

        for (int i = 0; i "lt" nRows; ++i)
        {
          if (X[i][j] "lt" colMin)
            colMin = X[i][j];
          if (X[i][j] "gt" colMax)
            colMax = X[i][j];
        }
        result[0][j] = colMin;
        result[1][j] = colMax;
      }

      return result;
    } // MatMinMaxValues

    // ------------------------------------------------------

    public static double[][] MatNormalizeUsing(double[][] X,
      double[][] minsMaxs)
    {
      // return normalized X, using mins and maxs
      int nRows = X.Length;
      int nCols = X[0].Length;
      double[][] result = new double[nRows][];
      for (int i = 0; i "lt" nRows; ++i)
        result[i] = new double[nCols];
      for (int j = 0; j "lt" nCols; ++j)
        for (int i = 0; i "lt" nRows; ++i)
          result[i][j] =
            (X[i][j] - minsMaxs[0][j]) /
            (minsMaxs[1][j] - minsMaxs[0][j]);
      return result;
    } // MatNormalizeUsing

    // ------------------------------------------------------

    public static double[] VecNormalizeUsing(double[] x,
      double[][] minsMaxs)
    {
      int dim = x.Length;
      double[] result = new double[dim];
      for (int j = 0; j "lt" dim; ++j)
        result[j] = 
          (x[j] - minsMaxs[0][j]) / 
          (minsMaxs[1][j] - minsMaxs[0][j]);
      return result;
    }

    // ------------------------------------------------------

    public static void MatShow(double[][] M, int dec,
      int wid, int numRows, bool showIndices)
    {
      double small = 1.0 / Math.Pow(10, dec);
      for (int i = 0; i "lt" numRows; ++i)
      {
        if (showIndices == true)
        {
          int pad = M.Length.ToString().Length;
          Console.Write("[" + i.ToString().
            PadLeft(pad) + "]");
        }
        for (int j = 0; j "lt" M[0].Length; ++j)
        {
          double v = M[i][j];
          if (Math.Abs(v) "lt" small) v = 0.0;
          Console.Write(v.ToString("F" + dec).
            PadLeft(wid));
        }
        Console.WriteLine("");
      }
      if (numRows "lt" M.Length) Console.WriteLine(". . .");
    }

    // ------------------------------------------------------

    public static void VecShow(int[] vec, int wid)
    {
      int n = vec.Length;
      for (int i = 0; i "lt" n; ++i)
      {
        if (i != 0 && i % 12 == 0) Console.WriteLine("");
        Console.Write(vec[i].ToString().PadLeft(wid));
      }
      Console.WriteLine("");
    }

    // ------------------------------------------------------

    public static void VecShow(int[] vec, int wid,
      int nItems)
    {
      //int n = vec.Length;
      for (int i = 0; i "lt" nItems; ++i)
      {
        if (i != 0 && i % 12 == 0) Console.WriteLine("");
        Console.Write(vec[i].ToString().PadLeft(wid));
      }
      Console.WriteLine("");
    }

    // ------------------------------------------------------

    public static void VecShow(double[] vec, int decimals,
      int wid)
    {
      int n = vec.Length;
      for (int i = 0; i "lt" n; ++i)
        Console.Write(vec[i].ToString("F" + decimals).
          PadLeft(wid));
      Console.WriteLine("");
    }

    // ------------------------------------------------------

  } // Program

  public class NearestCentroidClassifier
  {
    public int numClasses;
    public double[][] centroids;  // of each class

    public NearestCentroidClassifier(int numClasses)
    {
      this.numClasses = numClasses;
      this.centroids = new double[0][]; // keep compiler happy
    }

    public void Train(double[][] trainX, int[] trainY)
    {
      // compute centroid of each class
      int n = trainX.Length;
      int dim = trainX[0].Length;

      this.centroids = new double[this.numClasses][];
      for (int c = 0; c "lt" numClasses; ++c)
        this.centroids[c] = new double[dim];

      double[][] sums = new double[this.numClasses][];
      for (int c = 0; c "lt" numClasses; ++c)
        sums[c] = new double[dim];

      int[][] counts = new int[this.numClasses][];
      for (int c = 0; c "lt" numClasses; ++c)
        counts[c] = new int[dim];

      for (int i = 0; i "lt" n; ++i)
      {
        int c = trainY[i];
        for (int j = 0; j "lt" dim; ++j)
        {
          sums[c][j] += trainX[i][j];
          ++counts[c][j];
        }
      }

      for (int c = 0; c "lt" this.numClasses; ++c)
        for (int j = 0; j "lt" dim; ++j)
          this.centroids[c][j] = sums[c][j] / counts[c][j];

      // // less efficient but more clear
      //for (int c = 0; c "lt" this.numClasses; ++c)
      //{
      //  for (int j = 0; j "lt" dim; ++j) // each col
      //  {
      //    double colSum = 0.0;
      //    int colCount = 0;
      //    for (int i = 0; i "lt" n; ++i)  // each row
      //    {
      //      if (trainY[i] != c) continue;
      //      colSum += trainX[i][j];
      //      ++colCount;
      //    }
      //    this.centroids[c][j] = colSum / colCount;
      //  } // each col
      // } // each class

    } // Train

    // ------------------------------------------------------

    public int Predict(double[] x)
    {
      double[] distances = new double[this.numClasses];
      for (int c = 0; c "lt" this.numClasses; ++c)
        distances[c] = EucDistance(x, this.centroids[c]);
      double smallestDist = distances[0];
      int result = 0;
      for (int c = 0; c "lt" this.numClasses; ++c)
      {
        if (distances[c] "lt" smallestDist)
        {
          smallestDist = distances[c];
          result = c;
        }
      }
      return result;
    }

    // ------------------------------------------------------

    public double[] PredictProbs(double[] x)
    {
      double[] probs = new double[this.numClasses];
      double[] distances = new double[this.numClasses];
      double[] invDists = new double[this.numClasses];

      double sum = 0.0;  // of inverse distances
      for (int c = 0; c "lt" this.numClasses; ++c)
      {
        distances[c] = EucDistance(x, this.centroids[c]);
        if (distances[c] "lt" 0.00000001)
          distances[c] = 0.00000001;  // avoid div by 0
        invDists[c] = 1.0 / distances[c];
        sum += invDists[c];
      }
      for (int c = 0; c "lt" this.numClasses; ++c)
        probs[c] = invDists[c] / sum;
      return probs;  // pseudo-probabilities
    }

    // ------------------------------------------------------

    public double Accuracy(double[][] dataX, int[] dataY)
    {
      int nCorrect = 0;
      int nWrong = 0;
      int n = dataX.Length;
      for (int i = 0; i "lt" n; ++i)
      {
        int c = this.Predict(dataX[i]);
        //Console.WriteLine("actual = " + dataY[i]);
        //Console.WriteLine("predicted = " + c);
        //Console.ReadLine();
        if (c == dataY[i])
          ++nCorrect;
        else
          ++nWrong;
      }
      return (nCorrect * 1.0) / (nCorrect + nWrong);
    }

    // ------------------------------------------------------

    private double EucDistance(double[] v1, double[] v2)
    {
      int dim = v1.Length;
      double sum = 0.0;
      for (int d = 0; d "lt" dim; ++d)
        sum += (v1[d] - v2[d]) * (v1[d] - v2[d]);
      return Math.Sqrt(sum);
    }

    // ------------------------------------------------------

    public int[][] ConfusionMatrix(double[][] dataX,
      int[] dataY)
    {
      int n = this.numClasses;
      int[][] result = new int[n][];  // nxn
      for (int i = 0; i "lt" n; ++i)
        result[i] = new int[n];

      for (int i = 0; i "lt" dataX.Length; ++i)
      {
        double[] x = dataX[i];  // inputs
        int actualY = dataY[i];
        int predY = this.Predict(x);
        ++result[actualY][predY];
      }
      return result;
    }

    public void ShowConfusion(int[][] cm)
    {
      int n = cm.Length;
      int[] counts = new int[n];
      double[] accs = new double[n];
      for (int act = 0; act "lt" n; ++act) 
      {
        for (int pred = 0; pred "lt" n; ++pred)
        {
          counts[act] += cm[act][pred];
        }
      }

      for (int act = 0; act "lt" n; ++act)
      {
        accs[act] = (cm[act][act] * 1.0) / counts[act];
      }

      for (int i = 0; i "lt" n; ++i)
      {
        Console.Write("actual " + i + ": ");
        for (int j = 0; j "lt" n; ++j)
        {
          Console.Write(cm[i][j].ToString().
            PadLeft(4) + " ");
        }
        Console.Write(" | " + 
          counts[i].ToString().PadLeft(4));
        Console.Write(" | " + 
          accs[i].ToString("F4").PadLeft(7));
        Console.WriteLine("");
      }
    }

  } // class NearestCentroidClassifier

} // ns

// training data:
//
/*
# penguin_train_30.txt
# male only
# train mins = (34.60, 14.50, 180.00, 3300.00)
# train maxs = (52.70, 21.50, 230.00, 5850.00)
#
# 30 items
# actual: 0, 38.8, 17.2, 180.0, 3800.0
# replaced by: 1, 38.8, 17.2, 180.0, 3800.0
#
# actual: 0, 39.3, 20.6, 190.0, 3650.0
# replaced by: 2, 39.3, 20.6, 190.0, 3650.0
# to generate at least one bad prediction
#
2, 50.0, 16.3, 230.0, 5700.0
0, 39.1, 18.7, 181.0, 3750.0
# 0, 38.8, 17.2, 180.0, 3800.0
1, 38.8, 17.2, 180.0, 3800.0
# 0, 39.3, 20.6, 190.0, 3650.0
2, 39.3, 20.6, 190.0, 3650.0
0, 39.2, 19.6, 195.0, 4675.0
2, 50.0, 15.2, 218.0, 5700.0
1, 50.0, 19.5, 196.0, 3900.0
1, 51.3, 19.2, 193.0, 3650.0
0, 38.6, 21.2, 191.0, 3800.0
1, 52.7, 19.8, 197.0, 3725.0
2, 47.6, 14.5, 215.0, 5400.0
1, 51.3, 18.2, 197.0, 3750.0
2, 46.7, 15.3, 219.0, 5200.0
1, 51.3, 19.9, 198.0, 3700.0
0, 34.6, 21.1, 198.0, 4400.0
1, 51.7, 20.3, 194.0, 3775.0
1, 52.0, 18.1, 201.0, 4050.0
2, 46.8, 15.4, 215.0, 5150.0
0, 42.5, 20.7, 197.0, 4500.0
1, 50.5, 19.6, 201.0, 4050.0
1, 50.3, 20.0, 197.0, 3300.0
2, 49.0, 16.1, 216.0, 5550.0
0, 46.0, 21.5, 194.0, 4200.0
1, 49.2, 18.2, 195.0, 4400.0
2, 48.4, 14.6, 213.0, 5850.0
2, 49.3, 15.7, 217.0, 5850.0
0, 37.7, 18.7, 180.0, 3600.0
0, 38.2, 18.1, 185.0, 3950.0
1, 48.5, 17.5, 191.0, 3400.0
1, 50.6, 19.4, 193.0, 3800.0
 */

// test data:
/*
# penguin_test_10.txt
# male only
# 10 items
0, 40.6, 18.6, 183.0, 3550.0
0, 40.5, 18.9, 180.0, 3950.0
0, 37.2, 18.1, 178.0, 3900.0
0, 40.9, 18.9, 184.0, 3900.0
1, 52.0, 19.0, 197.0, 4150.0
1, 49.5, 19.0, 200.0, 3800.0
1, 52.8, 20.0, 205.0, 4550.0
2, 49.2, 15.2, 221.0, 6300.0
2, 48.7, 15.1, 222.0, 5350.0
2, 50.2, 14.3, 218.0, 5700.0
 */
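The demo leaves saving the trained model as a TODO. The model is nothing more than the centroids matrix, so one possible approach is to write that matrix to a text file. The sketch below assumes a simple format of one centroid per line with comma-separated values, and CentroidIO and its methods are hypothetical names. It writes values with the "R" round-trip format and the invariant culture so that Load() gets back the same double values.

```csharp
using System;
using System.Globalization;
using System.IO;
using System.Linq;

// Hypothetical helpers (not part of the demo) for the "save model" TODO.
// Format assumption: one centroid per line, values comma-separated.
public static class CentroidIO
{
  public static void Save(double[][] centroids, string fn)
  {
    // round-trip formatting with the invariant culture, so Load()
    // recovers the same double values on any machine locale
    string[] lines = centroids
      .Select(row => string.Join(",",
        row.Select(v => v.ToString("R", CultureInfo.InvariantCulture))))
      .ToArray();
    File.WriteAllLines(fn, lines);
  }

  public static double[][] Load(string fn)
  {
    return File.ReadLines(fn)
      .Where(line => line.Trim().Length > 0)  // ignore blank lines
      .Select(line => line.Split(',')
        .Select(tok => double.Parse(tok, CultureInfo.InvariantCulture))
        .ToArray())
      .ToArray();
  }
}
```

Because numClasses and centroids are public fields in the demo's NearestCentroidClassifier, a restored model could be created with new NearestCentroidClassifier(loaded.Length), followed by assigning the loaded matrix to its centroids field.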

3 Responses to Another Look at Nearest Centroid Classification for Data with Numeric Predictors, Using C#

  1. Thorsten Kleppe says:

    Very interesting post. NCC is similar to logistic regression. So I think both techniques can achieve similar accuracy. But maybe we need more experience to show that. It could be even better.

    One powerful idea is the representation space, where similar patterns are collected by training neural networks. K-means clustering is in a sense the representation space of a neural network layer from the input layer to the first hidden layer. So it may seem a bit weird at the moment, but building a deep centroid network, or whatever we should call it, seems promising, they could fill the entire space better than DNNs. But how? 🙂

  2. Vignes Anbah says:

    Hi absolutely great article. This article inspired me so much that I went on to create a production ready public facing tool based on what you talked about in this article.
    I’ve also given you and your blog due credit by linking your article from my site vedastro.org NCC Birth Time Predictor
    Thank you for your great work Dr. James D. McCaffrey
