Updated AdaBoost.R2 Regression Using C#

One morning before work, I decided to update my C# implementation of AdaBoost (“adaptive boosting”) regression. AdaBoost regression uses a collection of decision trees. Loosely, each tree is trained on a random selection of the training data, where the probability that an item is included in the current training set is proportional to how poorly predicted it is by previous trees.

More specifically, in very high-level pseudo-code:

train:

assign equal weights to all training items
create an empty list to hold decision trees (about 50 or 100)
for-each tree
  create a subset of training data where probability
   of inclusion is based on item's weight
  create and train a decision tree
  assign a confidence weight to new tree (better predictions
   give higher confidence weight)
  add newly trained tree to collection
  update training item weights (worse predictions give
   higher weight to item)
end-loop
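
To make the "update training item weights" step concrete, here is a minimal sketch of one round of the usual AdaBoost.R2 weight update with a linear loss. The class name WeightUpdateSketch, the method name UpdateWeights, and the four-item data are hypothetical, chosen only for illustration. The sketch is written with ordinary operator symbols and is not a drop-in replacement for the demo code.

```csharp
using System;

public static class WeightUpdateSketch
{
  // One AdaBoost.R2-style weight update (linear loss).
  // Returns new weights that sum to 1, or null when the update
  // is undefined: a perfect learner (largest error is 0) or a
  // weak learner (weighted average loss is 0.5 or more).
  public static double[] UpdateWeights(double[] weights,
    double[] actualY, double[] predY)
  {
    int n = weights.Length;
    double[] absDiffs = new double[n];
    double maxDiff = 0.0;
    for (int i = 0; i < n; ++i)
    {
      absDiffs[i] = Math.Abs(predY[i] - actualY[i]);
      if (absDiffs[i] > maxDiff) maxDiff = absDiffs[i];
    }
    if (maxDiff == 0.0) return null;

    // normalized loss per item, and weighted average loss
    double[] loss = new double[n];
    double avgLoss = 0.0;
    for (int i = 0; i < n; ++i)
    {
      loss[i] = absDiffs[i] / maxDiff;
      avgLoss += loss[i] * weights[i];
    }
    if (avgLoss >= 0.5) return null;

    // low beta means a confident learner
    double beta = avgLoss / (1.0 - avgLoss);

    // well-predicted items (loss near 0) are scaled by about
    // beta; the worst item (loss 1) keeps its weight
    double[] result = new double[n];
    double sum = 0.0;
    for (int i = 0; i < n; ++i)
    {
      result[i] = weights[i] * Math.Pow(beta, 1.0 - loss[i]);
      sum += result[i];
    }
    for (int i = 0; i < n; ++i)
      result[i] /= sum;
    return result;
  }

  public static void Main()
  {
    double[] w = { 0.25, 0.25, 0.25, 0.25 };  // equal start
    double[] y = { 0.50, 0.20, 0.80, 0.40 };  // made-up targets
    double[] p = { 0.50, 0.22, 0.70, 0.41 };  // made-up preds
    double[] newW = UpdateWeights(w, y, p);
    for (int i = 0; i < newW.Length; ++i)
      Console.WriteLine("item " + i + " weight = " +
        newW[i].ToString("F4"));
  }
}
```

In this made-up example the third item has the largest prediction error, so it ends up with the largest weight and is the most likely to be sampled for the next tree. The first item is predicted exactly and ends up with the smallest weight.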
 
predict:

for-each tree
  compute predicted y
end-loop 
return confidence-weighted median of the predictions
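
The weighted median in the predict step can behave quite differently from an ordinary median. Below is a small sketch with three made-up predictions and three made-up beta values (low beta means high confidence). The class name WeightedMedianSketch and the data are hypothetical. The model weight ln(1 / beta) follows the usual AdaBoost.R2 formulation, and the sketch uses ordinary operator symbols.

```csharp
using System;

public static class WeightedMedianSketch
{
  // Smallest value whose cumulative weight (in sorted-value
  // order) reaches half of the total weight. No interpolation.
  public static double WeightedMedian(double[] values,
    double[] weights)
  {
    int n = values.Length;
    double[] sorted = new double[n];
    int[] idxs = new int[n];
    double sumWts = 0.0;
    for (int i = 0; i < n; ++i)
    {
      sorted[i] = values[i];
      idxs[i] = i;
      sumWts += weights[i];
    }
    Array.Sort(sorted, idxs);  // sort values, carry indices

    double accum = 0.0;
    for (int j = 0; j < n; ++j)
    {
      accum += weights[idxs[j]];
      if (accum >= sumWts / 2)
        return sorted[j];
    }
    return sorted[n - 1];
  }

  public static void Main()
  {
    double[] preds = { 0.30, 0.50, 0.90 };  // made-up
    double[] betas = { 0.80, 0.50, 0.05 };  // made-up
    double[] modelWts = new double[betas.Length];
    for (int i = 0; i < betas.Length; ++i)
      modelWts[i] = Math.Log(1.0 / betas[i]);

    double wm = WeightedMedian(preds, modelWts);
    double plain = WeightedMedian(preds,
      new double[] { 1.0, 1.0, 1.0 });
    Console.WriteLine("weighted median = " + wm);
    Console.WriteLine("plain median    = " + plain);
  }
}
```

Here the third tree has a much lower beta than the other two, so its model weight is more than half of the total and the weighted median is 0.90. With equal weights the result is the ordinary median, 0.50.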

The original AdaBoost technique was developed in 1995 and is used for binary classification predictions. The original paper suggested how the technique could be modified to do regression. A 1997 research paper updated the suggestion for regression, and the result is called AdaBoost.R2 (adaptive boosting for regression version 2). As far as I know, adaptive boosting regression is only implemented using the AdaBoost.R2 algorithm, so the terms “adaptive boosting regression”, “AdaBoost regression”, and “AdaBoost.R2” are all essentially synonymous.
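
For reference, the per-round quantities of AdaBoost.R2 with a linear loss are usually written as shown below. The notation is chosen here for illustration: w_i is the current weight of training item i (weights sum to 1), y_i is the target value, and \hat{y}_i is the new tree's prediction.

```latex
\begin{align*}
D       &= \max_i \, | y_i - \hat{y}_i |
        && \text{largest absolute error} \\
L_i     &= \frac{| y_i - \hat{y}_i |}{D}
        && \text{normalized loss of item } i,\ 0 \le L_i \le 1 \\
\bar{L} &= \sum_i w_i \, L_i
        && \text{weighted average loss} \\
\beta   &= \frac{\bar{L}}{1 - \bar{L}}
        && \text{low } \beta \text{ = high confidence} \\
w_i     &\leftarrow \frac{w_i \, \beta^{\,1 - L_i}}
           {\sum_j w_j \, \beta^{\,1 - L_j}}
        && \text{update, then normalize}
\end{align*}
```

Because \beta is less than 1 whenever \bar{L} is less than 0.5, a well-predicted item (L_i near 0) has its weight multiplied by roughly \beta, while the worst-predicted item (L_i = 1) keeps its weight, so after normalization the poorly predicted items gain relative weight. At prediction time each tree t contributes with model weight \ln(1 / \beta_t) to the weighted median.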

The output of my demo:

Begin AdaBoost.R2 (tree) regression from scratch C#

Loading synthetic train (200) and test (40) data

First three train X:
 -0.1660  0.4406 -0.9998 -0.3953 -0.7065
  0.0776 -0.1616  0.3704 -0.5911  0.7562
 -0.9452  0.3409 -0.1654  0.1174 -0.7192

First three train y:
  0.4840
  0.1568
  0.8054

Setting maxLearners = 100
Setting tree maxDepth = 5
Setting tree minSamples = 2
Setting tree minLeaf = 1

Training AdaBoost.R2 model
Done
Created 100 learners

Accuracy train (within 0.10): 0.8250
Accuracy test (within 0.10): 0.5250

MSE train: 0.0004
MSE test: 0.0022

Predicting for x =
 -0.1660  0.4406 -0.9998 -0.3953 -0.7065
Predicted y = 0.4890

End demo

The results show that AdaBoost regression tends to overfit the data: good accuracy on the training data (0.8250) but not-as-good accuracy on new, previously unseen test data (0.5250).

I validated my AdaBoost regression implementation by running the demo data through the AdaBoostRegressor class in the scikit-learn Python library:

Setting n_estimators = 100
Setting max_depth = 5
Setting min_samples_split = 2
Setting min_samples_leaf = 1

Training AdaBoost.R2 model
Done
Created 100 learners

Accuracy train (within 0.10): 0.8200
Accuracy test (within 0.10): 0.5250

MSE train: 0.0002
MSE test: 0.0015

The results are pretty much the same. AdaBoost regression has a random component in the way training subsets are selected for each tree, so results will differ for different implementations.

AdaBoost regression is related to gradient boosting regression, which is most often used via the XGBoost and LightGBM libraries. AdaBoost regression is not used nearly as often as gradient boosting. I suspect that, in practice, gradient boosting regression gives better (less over-fitting) results.



AI is revolutionizing all kinds of things, including machine learning code generation. AI is also revolutionizing art. Here are three images I found during a more-or-less random search for “vintage tiki art”. All appear to be generated by AI but I’m not sure — I like the art regardless of how it was created.


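Demo program. Replace “lt” (less than), “gt”, “lte”, “gte” with the corresponding comparison operator symbols (my blog editor often chokes on symbols). The same “lt” and “gt” placeholders also stand in for the angle brackets of generic types such as List"lt"double"gt".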

using System;
using System.IO;
using System.Collections.Generic;

namespace AdaBoostRegression // AdaBoost.R2 algorithm
{
  internal class AdaBoostRegressionProgram
  {
    static void Main(string[] args)
    {
      Console.WriteLine("\nBegin AdaBoost.R2 (tree) " +
        "regression from scratch C# ");

      // 1. load data
      Console.WriteLine("\nLoading synthetic train (200)" +
        " and test (40) data");
      string trainFile =
        "..\\..\\..\\Data\\synthetic_train_200.txt";
      int[] colsX = new int[] { 0, 1, 2, 3, 4 };
      int colY = 5;

      double[][] trainX =
        MatLoad(trainFile, colsX, ',', "#");
      double[] trainY =
        MatToVec(MatLoad(trainFile,
        new int[] { colY }, ',', "#"));

      string testFile =
        "..\\..\\..\\Data\\synthetic_test_40.txt";
      double[][] testX =
        MatLoad(testFile, colsX, ',', "#");
      double[] testY =
        MatToVec(MatLoad(testFile,
        new int[] { colY }, ',', "#"));

      Console.WriteLine("\nFirst three train X: ");
      for (int i = 0; i "lt" 3; ++i)
        VecShow(trainX[i], 4, 8);

      Console.WriteLine("\nFirst three train y: ");
      for (int i = 0; i "lt" 3; ++i)
        Console.WriteLine(trainY[i].ToString("F4").
          PadLeft(8));

      // 2. create and train model
      int maxLearners = 100;
      int maxDepth = 5;
      int minSamples = 2;
      int minLeaf = 1;

      Console.WriteLine("\nSetting maxLearners = " +
        maxLearners);
      Console.WriteLine("Setting tree maxDepth = " +
        maxDepth);
      Console.WriteLine("Setting tree minSamples = " +
        minSamples);
      Console.WriteLine("Setting tree minLeaf = " +
        minLeaf);

      Console.WriteLine("\nTraining AdaBoost.R2 model ");
      AdaBoostRegressor model =
        new AdaBoostRegressor(maxLearners, maxDepth,
        minSamples, minLeaf, seed: 0); // master seed
      model.Train(trainX, trainY);
      Console.WriteLine("Done ");
      Console.WriteLine("Created " +
        model.learners.Count + " learners ");

      // 3. evaluate model
      double accTrain = model.Accuracy(trainX, trainY, 0.10);
      Console.WriteLine("\nAccuracy train (within 0.10): " +
        accTrain.ToString("F4"));
      double accTest = model.Accuracy(testX, testY, 0.10);
      Console.WriteLine("Accuracy test (within 0.10): " +
        accTest.ToString("F4"));

      double mseTrain = model.MSE(trainX, trainY);
      Console.WriteLine("\nMSE train: " +
        mseTrain.ToString("F4"));
      double mseTest = model.MSE(testX, testY);
      Console.WriteLine("MSE test: " +
        mseTest.ToString("F4"));

      // 4. use model to make a prediction
      double[] x = trainX[0];
      Console.WriteLine("\nPredicting for x = ");
      VecShow(x, 4, 8);
      double yPred = model.Predict(x);
      Console.WriteLine("Predicted y = " +
        yPred.ToString("F4"));

      Console.WriteLine("\nEnd demo ");
      Console.ReadLine();
    } // Main()

    // ------------------------------------------------------
    // helpers for Main():
    //   MatLoad(), MatToVec(), VecShow().
    // ------------------------------------------------------

    static double[][] MatLoad(string fn, int[] usecols,
      char sep, string comment)
    {
      List"lt"double[]"gt" result =
        new List"lt"double[]"gt"();
      string line = "";
      FileStream ifs = new FileStream(fn, FileMode.Open);
      StreamReader sr = new StreamReader(ifs);
      while ((line = sr.ReadLine()) != null)
      {
        if (line.StartsWith(comment) == true)
          continue;
        string[] tokens = line.Split(sep);
        List"lt"double"gt" lst = new List"lt"double"gt"();
        for (int j = 0; j "lt" usecols.Length; ++j)
          lst.Add(double.Parse(tokens[usecols[j]]));
        double[] row = lst.ToArray();
        result.Add(row);
      }
      sr.Close(); ifs.Close();
      return result.ToArray();
    }

    static double[] MatToVec(double[][] M)
    {
      int nRows = M.Length;
      int nCols = M[0].Length;
      double[] result = new double[nRows * nCols];
      int k = 0;
      for (int i = 0; i "lt" nRows; ++i)
        for (int j = 0; j "lt" nCols; ++j)
          result[k++] = M[i][j];
      return result;
    }

    static void VecShow(double[] vec, int dec, int wid)
    {
      for (int i = 0; i "lt" vec.Length; ++i)
        Console.Write(vec[i].ToString("F" + dec).
          PadLeft(wid));
      Console.WriteLine("");
    }

  } // class AdaBoostRegressionProgram

  // ========================================================

  class AdaBoostRegressor
  {
    public int maxLearners;
    public int maxDepth;
    public int minSamples;
    public int minLeaf;
    private Random rnd;
    private int seed; // master also passed to each tree
    public List"lt"DecisionTreeRegressor"gt" learners;
    public List"lt"double"gt" betas;

    public AdaBoostRegressor(int maxLearners, int maxDepth,
      int minSamples, int minLeaf, int seed)
    {
      this.maxLearners = maxLearners;
      this.maxDepth = maxDepth;
      this.minSamples = minSamples;
      this.minLeaf = minLeaf;
      this.rnd = new Random(seed);
      this.seed = seed;
      this.learners = new List"lt"DecisionTreeRegressor"gt"();
      this.betas = new List"lt"double"gt"();
      // each learner has a "model weight" derived
      // from its beta-confidence weight
    } // ctor

    public void Train(double[][] trainX, double[] trainY)
    {
      int N = trainX.Length;
      double[] weights = new double[N];
      for (int i = 0; i "lt" N; ++i)
        weights[i] = 1.0 / N;

      for (int t = 0; t "lt" this.maxLearners; ++t)
      {
        // 1. use weights to select a sample
        int[] sampleIdxs = this.SelectIdxs(weights, N);

        double[][] sampleX = ExtractRows(trainX,
          sampleIdxs);
        double[] sampleY = ExtractVals(trainY,
          sampleIdxs);

        // 2. construct regression machine t
        DecisionTreeRegressor lrnr =
          new DecisionTreeRegressor(this.maxDepth,
          this.minSamples, this.minLeaf, seed:this.seed);
        lrnr.Train(sampleX, sampleY);

        // 3. compute predicteds for all training items
        double[] predY = new double[N];
        for (int i = 0; i "lt" N; ++i)
          predY[i] = lrnr.Predict(trainX[i]);

        // 4. calculate loss each sample
        double[] absDiffs = new double[N];
        for (int i = 0; i "lt" N; ++i)
        {
          absDiffs[i] = Math.Abs(predY[i] - trainY[i]);
        }

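        double D = 0.0;  // largest abs diff
        for (int i = 0; i "lt" N; ++i)
        {
          if (absDiffs[i] "gt" D)
            D = absDiffs[i];
        }

        if (D == 0.0)  // perfect learner: avoid divide by 0
        {
          this.learners.Add(lrnr);
          this.betas.Add(1.0e-10);  // arbitrary tiny beta
          break;  // nothing left to boost
        }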

        double[] L = new double[N];
        for (int i = 0; i "lt" N; ++i)
          L[i] = absDiffs[i] / D;  // normalized

        // 5. calculate average loss
        double Lbar = 0.0;
        for (int i = 0; i "lt" N; ++i)
          Lbar += L[i] * weights[i];

        if (Lbar "gte" 0.5)  // bad learner
          continue;

        // 6. compute beta (low beta = high confidence)
        this.learners.Add(lrnr);  // lrnr is good
        double beta = Lbar / (1.0 - Lbar);
        this.betas.Add(beta); // each learner has a beta

        // 7. update training weights
        double sum = 0.0;
        for (int i = 0; i "lt" N; ++i)
        {
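          // use per-item loss L[i] so that poorly predicted
          // items gain relative weight after normalization
          weights[i] *= Math.Pow(beta, (1.0 - L[i]));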
          sum += weights[i];
        }
        for (int i = 0; i "lt" N; ++i)
          weights[i] /= sum;

      } // t

    } // Train()

    public double Predict(double[] x)
    {
      // each learner makes a prediction
      double[] preds = new double[this.learners.Count];
      for (int i = 0; i "lt" this.learners.Count; ++i)
      {
        preds[i] = this.learners[i].Predict(x);
      }

      // each learner has a "model weight" derived
      // from its beta-confidence weight
      double[] modelWts = new double[this.learners.Count];
      for (int i = 0; i "lt" this.learners.Count; ++i)
        modelWts[i] = Math.Log(1.0 / this.betas[i]);

      // final prediction is weighted median of 
      // the predictions
      double result = WeightedMedian(preds, modelWts);
      //double result = WeightedMean(preds, modelWts);

      return result;
    } // Predict()

    public double Accuracy(double[][] dataX,
        double[] dataY, double pctClose)
    {
      int nCorrect = 0; int nWrong = 0;
      for (int i = 0; i "lt" dataX.Length; ++i)
      {
        double predY = this.Predict(dataX[i]);
        double actualY = dataY[i];
        if (Math.Abs(predY - actualY) "lt"
          Math.Abs(pctClose * actualY))
          ++nCorrect;
        else
          ++nWrong;
      }
      return (nCorrect * 1.0) / (nCorrect + nWrong);
    }

    // ------------------------------------------------------

    public double MSE(double[][] dataX,
      double[] dataY)
    {
      int n = dataX.Length;
      double sum = 0.0;
      for (int i = 0; i "lt" n; ++i)
      {
        double actualY = dataY[i];
        double predY = this.Predict(dataX[i]);
        sum += (actualY - predY) * (actualY - predY);
      }
      return sum / n;
    }

    // ------------------------------------------------------
    // helper functions for AdaBoostRegressor
    // ------------------------------------------------------

    //private static double WeightedMean(double[] values,
    //  double[] weights)
    //{
    //  int n = values.Length;
    //  double sumWts = 0.0;
    //  double sum = 0.0;
    //  for (int i = 0; i "lt" n; ++i)
    //  {
    //    sumWts += weights[i];
    //    sum += values[i] * weights[i];
    //  }
    //  return sum / sumWts;
    //}

    // ------------------------------------------------------

    private static double WeightedMedian(double[] values,
      double[] weights)
    {
      // no interpolation for even n
      // don't assume weights sum to 1.0
      int n = values.Length;
      double sumWts = 0.0;
      for (int i = 0; i "lt" n; ++i)
        sumWts += weights[i];
      double thresh = sumWts / 2;
      int[] sortedIdxs = ArgSort(values);

      double accum = 0.0;
      for (int j = 0; j "lt" n; ++j)
      {
        accum += weights[sortedIdxs[j]];
        if (accum "gte" thresh)
          return values[sortedIdxs[j]];
      }
      return values[sortedIdxs[n - 1]];
    }

    // helper for WeightedMedian()
    private static int[] ArgSort(double[] values)
    {
      int n = values.Length;
      double[] copy = new double[n];
      int[] indices = new int[n];
      for (int i = 0; i "lt" n; ++i)
      {
        copy[i] = values[i];
        indices[i] = i;
      }
      Array.Sort(copy, indices);  // in parallel
      return indices;
    }

    private static double[][] MatMake(int nRows, int nCols)
    {
      double[][] result = new double[nRows][];
      for (int i = 0; i "lt" nRows; ++i)
        result[i] = new double[nCols];
      return result;
    }

    private int SelectIdx(double[] probs)
    {
      // roulette wheel selection
      double p = this.rnd.NextDouble();
      double accum = 0.0;
      for (int i = 0; i "lt" probs.Length; ++i)
      {
        accum += probs[i];
        if (p "lt" accum)
          return i;
      }
      return probs.Length - 1;
    }

    private int[] SelectIdxs(double[] probs, int n)
    {
      // with replacement
      int[] result = new int[n];
      for (int i = 0; i "lt" n; ++i)
        result[i] = this.SelectIdx(probs);
      return result;
    }

    private static double[][] ExtractRows(double[][] mat,
      int[] idxs)
    {
      int nRows = idxs.Length;  // of the result
      int nCols = mat[0].Length;  // source and result
      double[][] result = MatMake(nRows, nCols);

      for (int i = 0; i "lt" nRows; ++i)
      {
        for (int j = 0; j "lt" nCols; ++j)
        {
          int srcRow = idxs[i];
          result[i][j] = mat[srcRow][j];
        }
      }
      return result;
    }

    private static double[] ExtractVals(double[] vec,
      int[] idxs)
    {
      int nVals = idxs.Length;
      double[] result = new double[nVals];
      for (int i = 0; i "lt" nVals; ++i)
      {
        int srcIdx = idxs[i];
        result[i] = vec[srcIdx];
      }
      return result;
    }

  } // class AdaBoostRegressor

  // ========================================================

  public class DecisionTreeRegressor
  {
    public int maxDepth;
    public int minSamples;  // aka min_samples_split
    public int minLeaf;  // min number of values in a leaf
    public int numSplitCols; // mostly for random forest
    public List"lt"Node"gt" tree = new List"lt"Node"gt"();
    public Random rnd;  // order in which cols are searched

    public double[][] trainX;  // store data by ref
    public double[] trainY;

    // ------------------------------------------------------

    public class Node
    {
      public int id;
      public int colIdx;  // aka featureIdx
      public double thresh;
      public int left;  // index into List
      public int right;
      public double value;
      public bool isLeaf;
      public List"lt"int"gt" rows;  // rows in train data

      public Node()
      {
        this.id = -1;
        this.colIdx = -1;
        this.thresh = 0.0;  // aka split value
        this.left = -1;
        this.right = -1;
        this.value = 0.0;  // aka pred y
        this.isLeaf = false;
        this.rows = null;
      }
    } // class Node

    // --------------------------------------------

    public DecisionTreeRegressor(int maxDepth = 2,
      int minSamples = 2, int minLeaf = 1,
      int numSplitCols = -1, int seed = 0)
    {
      // if maxDepth = 0, tree has just a root node
      // if maxDepth = 1, at most 3 nodes (root, l, r)
      // if maxDepth = n, at most 2^(n+1) - 1 nodes
      this.maxDepth = maxDepth;
      this.minSamples = minSamples;
      this.minLeaf = minLeaf;
      this.numSplitCols = numSplitCols;  // for ran. forest

      // create full tree List with null nodes
      int numNodes = (int)Math.Pow(2, (maxDepth + 1)) - 1;
      for (int i = 0; i "lt" numNodes; ++i)
      {
        this.tree.Add(null);  // empty nodes
      }
      this.rnd = new Random(seed);
    }

    // ------------------------------------------------------
    // public: Train(), Predict().
    // helpers: MakeTree(), BestSplit(), TreeTargetMean(),
    //   TreeTargetVariance().
    // ------------------------------------------------------

    public void Train(double[][] trainX, double[] trainY)
    {
      this.trainX = trainX; // stored by reference, no copy
      this.trainY = trainY;
      this.MakeTree();

      // optionally delete rows in each Node to save space
      // when tree is part of an ensemble
      for (int i = 0; i "lt" this.tree.Count; ++i)
        if (this.tree[i] != null)
          this.tree[i].rows = null;
    }

    // ------------------------------------------------------

    public double Predict(double[] x)
    {
      int p = 0;
      Node currNode = this.tree[p];
      while (currNode != null &&
        currNode.isLeaf == false &&
        p "lt" this.tree.Count)
      {
        if (x[currNode.colIdx] "lte" currNode.thresh)
          p = currNode.left;
        else
          p = currNode.right;
        currNode = this.tree[p];
      }
      return this.tree[p].value;
    }

    // ------------------------------------------------------

    // helpers: MakeTree(), BestSplit(),
    // TreeTargetMean(), TreeTargetVariance()

    private void MakeTree()
    {
      // no recursion, no pointers, List storage, no stack
      if (this.numSplitCols == -1) // use all cols
        this.numSplitCols = this.trainX[0].Length;

      // prepare root node
      List"lt"int"gt" allRows = new List"lt"int"gt"();
      for (int i = 0; i "lt" this.trainX.Length; ++i)
        allRows.Add(i);
      double grandMean = this.TreeTargetMean(allRows);

      // wait to supply colIdx and thresh in loop
      Node root = new Node();
      root.id = 0;
      root.left = 1;
      root.right = 2;
      root.value = grandMean;
      root.isLeaf = false; // already set
      root.rows = allRows;
      this.tree[0] = root;

      for (int i = 0; i "lt" this.tree.Count; ++i)
      {
        Node currNode = this.tree[i];
        // curr node has values except colIdx and thresh

        // curr node too deep to have children OR
        // curr node not enough rows to split then
        // leave both children as null
        if (currNode == null ||
          currNode.rows.Count == 0) { continue; }

        // if parent cannot be split, make parent a leaf
        if (currNode.id "gte" (int)Math.Pow(2,
          (this.maxDepth)) - 1 ||
          currNode.rows.Count "lt" this.minSamples)
        {
          currNode.isLeaf = true;
          continue;
        }

        // parent has enough rows to try to split
        double[] splitInfo = this.BestSplit(currNode.rows);
        int colIdx = (int)splitInfo[0];
        double splitVal = splitInfo[1]; //split value

        if (colIdx == -1)  // unable to split, parent leaf
        {
          currNode.isLeaf = true;
          continue;
        }

        // complete the fields for curr node
        currNode.colIdx = colIdx;
        currNode.thresh = splitVal;

        // construct the children, except colIdx and thresh
        Node leftNode = new Node();
        Node rightNode = new Node();

        // construct the children rows using split info
        // all info except colIdx and thresh
        List"lt"int"gt" leftIdxs = new List"lt"int"gt"();
        List"lt"int"gt" rightIdxs = new List"lt"int"gt"();
        for (int k = 0; k "lt" currNode.rows.Count; ++k)
        {
          int r = currNode.rows[k];
          if (this.trainX[r][colIdx] "lte" splitVal)
            leftIdxs.Add(r);
          else
            rightIdxs.Add(r);
        }

        // assign -1 to children if out of range
        leftNode.id = currNode.id * 2 + 1;
        if (leftNode.id "gt" (int)Math.Pow(2,
          (maxDepth + 1)) - 2) leftNode.id = -1;
        leftNode.left = leftNode.id * 2 + 1;
        if (leftNode.left "gt" (int)Math.Pow(2,
          (maxDepth + 1)) - 2) leftNode.left = -1;
        leftNode.right = leftNode.id * 2 + 2;
        if (leftNode.right "gt" (int)Math.Pow(2,
          (maxDepth + 1)) - 2) leftNode.right = -1;
        leftNode.rows = leftIdxs;
        leftNode.value =
          this.TreeTargetMean(leftNode.rows);
        this.tree[leftNode.id] = leftNode;

        rightNode.id = currNode.id * 2 + 2;
        if (rightNode.id "gt" (int)Math.Pow(2,
          (maxDepth + 1)) - 2) rightNode.id = -1;
        rightNode.left = rightNode.id * 2 + 1;
        if (rightNode.left "gt" (int)Math.Pow(2,
          (maxDepth + 1)) - 2) rightNode.left = -1;
        rightNode.right = rightNode.id * 2 + 2;
        if (rightNode.right "gt" (int)Math.Pow(2,
          (maxDepth + 1)) - 2) rightNode.right = -1;
        rightNode.rows = rightIdxs;
        rightNode.value =
          this.TreeTargetMean(rightNode.rows);
        this.tree[rightNode.id] = rightNode;

      } // i
      return;
    }

    // ------------------------------------------------------

    private double[] BestSplit(List"lt"int"gt" rows)
    {
      // implicit params: numSplitCols, minLeaf
      // result[0] = best col idx (as double)
      // result[1] = best split value
      rows.Sort();

      int bestColIdx = -1;  // indicates bad split
      double bestThresh = 0.0;
      double bestVar = double.MaxValue; // smaller better

      int nRows = rows.Count;  // or dataY.Length
      int nCols = this.trainX[0].Length;

      if (nRows == 0)
      {
        throw new Exception("empty data in BestSplit()");
      }

      // process cols in scrambled order
      int[] colIndices = new int[nCols];
      for (int k = 0; k "lt" nCols; ++k)
        colIndices[k] = k;
      // shuffle, inline Fisher-Yates
      int n = colIndices.Length;
      for (int i = 0; i "lt" n; ++i)
      {
        int ri = rnd.Next(i, n);  // be careful
        int tmp = colIndices[i];
        colIndices[i] = colIndices[ri];
        colIndices[ri] = tmp;
      }

      // numSplitCols is usually all columns (-1)
      for (int j = 0; j "lt" this.numSplitCols; ++j)
      {
        int colIdx = colIndices[j];
        HashSet"lt"double"gt" examineds = 
          new HashSet"lt"double"gt"();

        for (int i = 0; i "lt" nRows; ++i) // each row
        {
          // if curr thresh been seen, skip it
          double thresh = this.trainX[rows[i]][colIdx];
          if (examineds.Contains(thresh)) continue;
          examineds.Add(thresh);

          // get row idxs where x is lte, gt thresh
          List"lt"int"gt" leftIdxs = new List"lt"int"gt"();
          List"lt"int"gt" rightIdxs = new List"lt"int"gt"();
          for (int k = 0; k "lt" nRows; ++k)
          {
            if (this.trainX[rows[k]][colIdx] "lte" thresh)
              leftIdxs.Add(rows[k]);
            else
              rightIdxs.Add(rows[k]);
          }

          // Check if proposed split has too few values
          if (leftIdxs.Count "lt" this.minLeaf ||
            rightIdxs.Count "lt" this.minLeaf)
            continue;  // to next row

          double leftVar =
            this.TreeTargetVariance(leftIdxs);
          double rightVar =
            this.TreeTargetVariance(rightIdxs);
          double weightedVar = (leftIdxs.Count * leftVar +
            rightIdxs.Count * rightVar) / nRows;

          if (weightedVar "lt" bestVar)
          {
            // if this never happens, bestColIdx remains -1
            // which means a bad split. used in MakeTree()
            bestColIdx = colIdx;
            bestThresh = thresh;
            bestVar = weightedVar;
          }

        } // each row
      } // j each col

      double[] result = new double[2];  // out params ugly
      result[0] = 1.0 * bestColIdx;
      result[1] = bestThresh;
      return result;

    } // BestSplit()

    // ------------------------------------------------------

    private double TreeTargetMean(List"lt"int"gt" rows)
    {
      // mean of rows items in trainY: for node prediction
      double sum = 0.0;
      for (int i = 0; i "lt" rows.Count; ++i)
      {
        int r = rows[i];
        sum += this.trainY[r];
      }
      return sum / rows.Count;
    }

    // ------------------------------------------------------

    private double TreeTargetVariance(List"lt"int"gt" rows)
    {
      double mean = this.TreeTargetMean(rows);
      double sum = 0.0;
      for (int i = 0; i "lt" rows.Count; ++i)
      {
        int r = rows[i];
        sum += (this.trainY[r] - mean) *
          (this.trainY[r] - mean);
      }
      return sum / rows.Count;
    }

    // ------------------------------------------------------

  } // class DecisionTreeRegressor

  // ========================================================

} // ns

Training data:

# synthetic_train_200.txt
#
-0.1660,  0.4406, -0.9998, -0.3953, -0.7065,  0.4840
 0.0776, -0.1616,  0.3704, -0.5911,  0.7562,  0.1568
-0.9452,  0.3409, -0.1654,  0.1174, -0.7192,  0.8054
 0.9365, -0.3732,  0.3846,  0.7528,  0.7892,  0.1345
-0.8299, -0.9219, -0.6603,  0.7563, -0.8033,  0.7955
 0.0663,  0.3838, -0.3690,  0.3730,  0.6693,  0.3206
-0.9634,  0.5003,  0.9777,  0.4963, -0.4391,  0.7377
-0.1042,  0.8172, -0.4128, -0.4244, -0.7399,  0.4801
-0.9613,  0.3577, -0.5767, -0.4689, -0.0169,  0.6861
-0.7065,  0.1786,  0.3995, -0.7953, -0.1719,  0.5569
 0.3888, -0.1716, -0.9001,  0.0718,  0.3276,  0.2500
 0.1731,  0.8068, -0.7251, -0.7214,  0.6148,  0.3297
-0.2046, -0.6693,  0.8550, -0.3045,  0.5016,  0.2129
 0.2473,  0.5019, -0.3022, -0.4601,  0.7918,  0.2613
-0.1438,  0.9297,  0.3269,  0.2434, -0.7705,  0.5171
 0.1568, -0.1837, -0.5259,  0.8068,  0.1474,  0.3307
-0.9943,  0.2343, -0.3467,  0.0541,  0.7719,  0.5581
 0.2467, -0.9684,  0.8589,  0.3818,  0.9946,  0.1092
-0.6553, -0.7257,  0.8652,  0.3936, -0.8680,  0.7018
 0.8460,  0.4230, -0.7515, -0.9602, -0.9476,  0.1996
-0.9434, -0.5076,  0.7201,  0.0777,  0.1056,  0.5664
 0.9392,  0.1221, -0.9627,  0.6013, -0.5341,  0.1533
 0.6142, -0.2243,  0.7271,  0.4942,  0.1125,  0.1661
 0.4260,  0.1194, -0.9749, -0.8561,  0.9346,  0.2230
 0.1362, -0.5934, -0.4953,  0.4877, -0.6091,  0.3810
 0.6937, -0.5203, -0.0125,  0.2399,  0.6580,  0.1460
-0.6864, -0.9628, -0.8600, -0.0273,  0.2127,  0.5387
 0.9772,  0.1595, -0.2397,  0.1019,  0.4907,  0.1611
 0.3385, -0.4702, -0.8673, -0.2598,  0.2594,  0.2270
-0.8669, -0.4794,  0.6095, -0.6131,  0.2789,  0.4700
 0.0493,  0.8496, -0.4734, -0.8681,  0.4701,  0.3516
 0.8639, -0.9721, -0.5313,  0.2336,  0.8980,  0.1412
 0.9004,  0.1133,  0.8312,  0.2831, -0.2200,  0.1782
 0.0991,  0.8524,  0.8375, -0.2102,  0.9265,  0.2150
-0.6521, -0.7473, -0.7298,  0.0113, -0.9570,  0.7422
 0.6190, -0.3105,  0.8802,  0.1640,  0.7577,  0.1056
 0.6895,  0.8108, -0.0802,  0.0927,  0.5972,  0.2214
 0.1982, -0.9689,  0.1870, -0.1326,  0.6147,  0.1310
-0.3695,  0.7858,  0.1557, -0.6320,  0.5759,  0.3773
-0.1596,  0.3581,  0.8372, -0.9992,  0.9535,  0.2071
-0.2468,  0.9476,  0.2094,  0.6577,  0.1494,  0.4132
 0.1737,  0.5000,  0.7166,  0.5102,  0.3961,  0.2611
 0.7290, -0.3546,  0.3416, -0.0983, -0.2358,  0.1332
-0.3652,  0.2438, -0.1395,  0.9476,  0.3556,  0.4170
-0.6029, -0.1466, -0.3133,  0.5953,  0.7600,  0.4334
-0.4596, -0.4953,  0.7098,  0.0554,  0.6043,  0.2775
 0.1450,  0.4663,  0.0380,  0.5418,  0.1377,  0.2931
-0.8636, -0.2442, -0.8407,  0.9656, -0.6368,  0.7429
 0.6237,  0.7499,  0.3768,  0.1390, -0.6781,  0.2185
-0.5499,  0.1850, -0.3755,  0.8326,  0.8193,  0.4399
-0.4858, -0.7782, -0.6141, -0.0008,  0.4572,  0.4197
 0.7033, -0.1683,  0.2334, -0.5327, -0.7961,  0.1776
 0.0317, -0.0457, -0.6947,  0.2436,  0.0880,  0.3345
 0.5031, -0.5559,  0.0387,  0.5706, -0.9553,  0.3107
-0.3513,  0.7458,  0.6894,  0.0769,  0.7332,  0.3170
 0.2205,  0.5992, -0.9309,  0.5405,  0.4635,  0.3532
-0.4806, -0.4859,  0.2646, -0.3094,  0.5932,  0.3202
 0.9809, -0.3995, -0.7140,  0.8026,  0.0831,  0.1600
 0.9495,  0.2732,  0.9878,  0.0921,  0.0529,  0.1289
-0.9476, -0.6792,  0.4913, -0.9392, -0.2669,  0.5966
 0.7247,  0.3854,  0.3819, -0.6227, -0.1162,  0.1550
-0.5922, -0.5045, -0.4757,  0.5003, -0.0860,  0.5863
-0.8861,  0.0170, -0.5761,  0.5972, -0.4053,  0.7301
 0.6877, -0.2380,  0.4997,  0.0223,  0.0819,  0.1404
 0.9189,  0.6079, -0.9354,  0.4188, -0.0700,  0.1907
-0.1428, -0.7820,  0.2676,  0.6059,  0.3936,  0.2790
 0.5324, -0.3151,  0.6917, -0.1425,  0.6480,  0.1071
-0.8432, -0.9633, -0.8666, -0.0828, -0.7733,  0.7784
-0.9444,  0.5097, -0.2103,  0.4939, -0.0952,  0.6787
-0.0520,  0.6063, -0.1952,  0.8094, -0.9259,  0.4836
 0.5477, -0.7487,  0.2370, -0.9793,  0.0773,  0.1241
 0.2450,  0.8116,  0.9799,  0.4222,  0.4636,  0.2355
 0.8186, -0.1983, -0.5003, -0.6531, -0.7611,  0.1511
-0.4714,  0.6382, -0.3788,  0.9648, -0.4667,  0.5950
 0.0673, -0.3711,  0.8215, -0.2669, -0.1328,  0.2677
-0.9381,  0.4338,  0.7820, -0.9454,  0.0441,  0.5518
-0.3480,  0.7190,  0.1170,  0.3805, -0.0943,  0.4724
-0.9813,  0.1535, -0.3771,  0.0345,  0.8328,  0.5438
-0.1471, -0.5052, -0.2574,  0.8637,  0.8737,  0.3042
-0.5454, -0.3712, -0.6505,  0.2142, -0.1728,  0.5783
 0.6327, -0.6297,  0.4038, -0.5193,  0.1484,  0.1153
-0.5424,  0.3282, -0.0055,  0.0380, -0.6506,  0.6613
 0.1414,  0.9935,  0.6337,  0.1887,  0.9520,  0.2540
-0.9351, -0.8128, -0.8693, -0.0965, -0.2491,  0.7353
 0.9507, -0.6640,  0.9456,  0.5349,  0.6485,  0.1059
-0.0462, -0.9737, -0.2940, -0.0159,  0.4602,  0.2606
-0.0627, -0.0852, -0.7247, -0.9782,  0.5166,  0.2977
 0.0478,  0.5098, -0.0723, -0.7504, -0.3750,  0.3335
 0.0090,  0.3477,  0.5403, -0.7393, -0.9542,  0.4415
-0.9748,  0.3449,  0.3736, -0.1015,  0.8296,  0.4358
 0.2887, -0.9895, -0.0311,  0.7186,  0.6608,  0.2057
 0.1570, -0.4518,  0.1211,  0.3435, -0.2951,  0.3244
 0.7117, -0.6099,  0.4946, -0.4208,  0.5476,  0.1096
-0.2929, -0.5726,  0.5346, -0.3827,  0.4665,  0.2465
 0.4889, -0.5572, -0.5718, -0.6021, -0.7150,  0.2163
-0.7782,  0.3491,  0.5996, -0.8389, -0.5366,  0.6516
-0.5847,  0.8347,  0.4226,  0.1078, -0.3910,  0.6134
 0.8469,  0.4121, -0.0439, -0.7476,  0.9521,  0.1571
-0.6803, -0.5948, -0.1376, -0.1916, -0.7065,  0.7156
 0.2878,  0.5086, -0.5785,  0.2019,  0.4979,  0.2980
 0.2764,  0.1943, -0.4090,  0.4632,  0.8906,  0.2960
-0.8877,  0.6705, -0.6155, -0.2098, -0.3998,  0.7107
-0.8398,  0.8093, -0.2597,  0.0614, -0.0118,  0.6502
-0.8476,  0.0158, -0.4769, -0.2859, -0.7839,  0.7715
 0.5751, -0.7868,  0.9714, -0.6457,  0.1448,  0.1175
 0.4802, -0.7001,  0.1022, -0.5668,  0.5184,  0.1090
 0.4458, -0.6469,  0.7239, -0.9604,  0.7205,  0.0779
 0.5175,  0.4339,  0.9747, -0.4438, -0.9924,  0.2879
 0.8678,  0.7158,  0.4577,  0.0334,  0.4139,  0.1678
 0.5406,  0.5012,  0.2264, -0.1963,  0.3946,  0.2088
-0.9938,  0.5498,  0.7928, -0.5214, -0.7585,  0.7687
 0.7661,  0.0863, -0.4266, -0.7233, -0.4197,  0.1466
 0.2277, -0.3517, -0.0853, -0.1118,  0.6563,  0.1767
 0.3499, -0.5570, -0.0655, -0.3705,  0.2537,  0.1632
 0.7547, -0.1046,  0.5689, -0.0861,  0.3125,  0.1257
 0.8186,  0.2110,  0.5335,  0.0094, -0.0039,  0.1391
 0.6858, -0.8644,  0.1465,  0.8855,  0.0357,  0.1845
-0.4967,  0.4015,  0.0805,  0.8977,  0.2487,  0.4663
 0.6760, -0.9841,  0.9787, -0.8446, -0.3557,  0.1509
-0.1203, -0.4885,  0.6054, -0.0443, -0.7313,  0.4854
 0.8557,  0.7919, -0.0169,  0.7134, -0.1628,  0.2002
 0.0115, -0.6209,  0.9300, -0.4116, -0.7931,  0.4052
-0.7114, -0.9718,  0.4319,  0.1290,  0.5892,  0.3661
 0.3915,  0.5557, -0.1870,  0.2955, -0.6404,  0.2954
-0.3564, -0.6548, -0.1827, -0.5172, -0.1862,  0.4622
 0.2392, -0.4959,  0.5857, -0.1341, -0.2850,  0.2470
-0.3394,  0.3947, -0.4627,  0.6166, -0.4094,  0.5325
 0.7107,  0.7768, -0.6312,  0.1707,  0.7964,  0.2757
-0.1078,  0.8437, -0.4420,  0.2177,  0.3649,  0.4028
-0.3139,  0.5595, -0.6505, -0.3161, -0.7108,  0.5546
 0.4335,  0.3986,  0.3770, -0.4932,  0.3847,  0.1810
-0.2562, -0.2894, -0.8847,  0.2633,  0.4146,  0.4036
 0.2272,  0.2966, -0.6601, -0.7011,  0.0284,  0.2778
-0.0743, -0.1421, -0.0054, -0.6770, -0.3151,  0.3597
-0.4762,  0.6891,  0.6007, -0.1467,  0.2140,  0.4266
-0.4061,  0.7193,  0.3432,  0.2669, -0.7505,  0.6147
-0.0588,  0.9731,  0.8966,  0.2902, -0.6966,  0.4955
-0.0627, -0.1439,  0.1985,  0.6999,  0.5022,  0.3077
 0.1587,  0.8494, -0.8705,  0.9827, -0.8940,  0.4263
-0.7850,  0.2473, -0.9040, -0.4308, -0.8779,  0.7199
 0.4070,  0.3369, -0.2428, -0.6236,  0.4940,  0.2215
-0.0242,  0.0513, -0.9430,  0.2885, -0.2987,  0.3947
-0.5416, -0.1322, -0.2351, -0.0604,  0.9590,  0.3683
 0.1055,  0.7783, -0.2901, -0.5090,  0.8220,  0.2984
-0.9129,  0.9015,  0.1128, -0.2473,  0.9901,  0.4776
-0.9378,  0.1424, -0.6391,  0.2619,  0.9618,  0.5368
 0.7498, -0.0963,  0.4169,  0.5549, -0.0103,  0.1614
-0.2612, -0.7156,  0.4538, -0.0460, -0.1022,  0.3717
 0.7720,  0.0552, -0.1818, -0.4622, -0.8560,  0.1685
-0.4177,  0.0070,  0.9319, -0.7812,  0.3461,  0.3052
-0.0001,  0.5542, -0.7128, -0.8336, -0.2016,  0.3803
 0.5356, -0.4194, -0.5662, -0.9666, -0.2027,  0.1776
-0.2378,  0.3187, -0.8582, -0.6948, -0.9668,  0.5474
-0.1947, -0.3579,  0.1158,  0.9869,  0.6690,  0.2992
 0.3992,  0.8365, -0.9205, -0.8593, -0.0520,  0.3154
-0.0209,  0.0793,  0.7905, -0.1067,  0.7541,  0.1864
-0.4928, -0.4524, -0.3433,  0.0951, -0.5597,  0.6261
-0.8118,  0.7404, -0.5263, -0.2280,  0.1431,  0.6349
 0.0516, -0.8480,  0.7483,  0.9023,  0.6250,  0.1959
-0.3212,  0.1093,  0.9488, -0.3766,  0.3376,  0.2735
-0.3481,  0.5490, -0.3484,  0.7797,  0.5034,  0.4379
-0.5785, -0.9170, -0.3563, -0.9258,  0.3877,  0.4121
 0.3407, -0.1391,  0.5356,  0.0720, -0.9203,  0.3458
-0.3287, -0.8954,  0.2102,  0.0241,  0.2349,  0.3247
-0.1353,  0.6954, -0.0919, -0.9692,  0.7461,  0.3338
 0.9036, -0.8982, -0.5299, -0.8733, -0.1567,  0.1187
 0.7277, -0.8368, -0.0538, -0.7489,  0.5458,  0.0830
 0.9049,  0.8878,  0.2279,  0.9470, -0.3103,  0.2194
 0.7957, -0.1308, -0.5284,  0.8817,  0.3684,  0.2172
 0.4647, -0.4931,  0.2010,  0.6292, -0.8918,  0.3371
-0.7390,  0.6849,  0.2367,  0.0626, -0.5034,  0.7039
-0.1567, -0.8711,  0.7940, -0.5932,  0.6525,  0.1710
 0.7635, -0.0265,  0.1969,  0.0545,  0.2496,  0.1445
 0.7675,  0.1354, -0.7698, -0.5460,  0.1920,  0.1728
-0.5211, -0.7372, -0.6763,  0.6897,  0.2044,  0.5217
 0.1913,  0.1980,  0.2314, -0.8816,  0.5006,  0.1998
 0.8964,  0.0694, -0.6149,  0.5059, -0.9854,  0.1825
 0.1767,  0.7104,  0.2093,  0.6452,  0.7590,  0.2832
-0.3580, -0.7541,  0.4426, -0.1193, -0.7465,  0.5657
-0.5996,  0.5766, -0.9758, -0.3933, -0.9572,  0.6800
 0.9950,  0.1641, -0.4132,  0.8579,  0.0142,  0.2003
-0.4717, -0.3894, -0.2567, -0.5111,  0.1691,  0.4266
 0.3917, -0.8561,  0.9422,  0.5061,  0.6123,  0.1212
-0.0366, -0.1087,  0.3449, -0.1025,  0.4086,  0.2475
 0.3633,  0.3943,  0.2372, -0.6980,  0.5216,  0.1925
-0.5325, -0.6466, -0.2178, -0.3589,  0.6310,  0.3568
 0.2271,  0.5200, -0.1447, -0.8011, -0.7699,  0.3128
 0.6415,  0.1993,  0.3777, -0.0178, -0.8237,  0.2181
-0.5298, -0.0768, -0.6028, -0.9490,  0.4588,  0.4356
 0.6870, -0.1431,  0.7294,  0.3141,  0.1621,  0.1632
-0.5985,  0.0591,  0.7889, -0.3900,  0.7419,  0.2945
 0.3661,  0.7984, -0.8486,  0.7572, -0.6183,  0.3449
 0.6995,  0.3342, -0.3113, -0.6972,  0.2707,  0.1712
 0.2565,  0.9126,  0.1798, -0.6043, -0.1413,  0.2893
-0.3265,  0.9839, -0.2395,  0.9854,  0.0376,  0.4770
 0.2690, -0.1722,  0.9818,  0.8599, -0.7015,  0.3954
-0.2102, -0.0768,  0.1219,  0.5607, -0.0256,  0.3949
 0.8216, -0.9555,  0.6422, -0.6231,  0.3715,  0.0801
-0.2896,  0.9484, -0.7545, -0.6249,  0.7789,  0.4370
-0.9985, -0.5448, -0.7092, -0.5931,  0.7926,  0.5402

Test data:

# synthetic_test_40.txt
#
 0.7462,  0.4006, -0.0590,  0.6543, -0.0083,  0.1935
 0.8495, -0.2260, -0.0142, -0.4911,  0.7699,  0.1078
-0.2335, -0.4049,  0.4352, -0.6183, -0.7636,  0.5088
 0.1810, -0.5142,  0.2465,  0.2767, -0.3449,  0.3136
-0.8650,  0.7611, -0.0801,  0.5277, -0.4922,  0.7140
-0.2358, -0.7466, -0.5115, -0.8413, -0.3943,  0.4533
 0.4834,  0.2300,  0.3448, -0.9832,  0.3568,  0.1360
-0.6502, -0.6300,  0.6885,  0.9652,  0.8275,  0.3046
-0.3053,  0.5604,  0.0929,  0.6329, -0.0325,  0.4756
-0.7995,  0.0740, -0.2680,  0.2086,  0.9176,  0.4565
-0.2144, -0.2141,  0.5813,  0.2902, -0.2122,  0.4119
-0.7278, -0.0987, -0.3312, -0.5641,  0.8515,  0.4438
 0.3793,  0.1976,  0.4933,  0.0839,  0.4011,  0.1905
-0.8568,  0.9573, -0.5272,  0.3212, -0.8207,  0.7415
-0.5785,  0.0056, -0.7901, -0.2223,  0.0760,  0.5551
 0.0735, -0.2188,  0.3925,  0.3570,  0.3746,  0.2191
 0.1230, -0.2838,  0.2262,  0.8715,  0.1938,  0.2878
 0.4792, -0.9248,  0.5295,  0.0366, -0.9894,  0.3149
-0.4456,  0.0697,  0.5359, -0.8938,  0.0981,  0.3879
 0.8629, -0.8505, -0.4464,  0.8385,  0.5300,  0.1769
 0.1995,  0.6659,  0.7921,  0.9454,  0.9970,  0.2330
-0.0249, -0.3066, -0.2927, -0.4923,  0.8220,  0.2437
 0.4513, -0.9481, -0.0770, -0.4374, -0.9421,  0.2879
-0.3405,  0.5931, -0.3507, -0.3842,  0.8562,  0.3987
 0.9538,  0.0471,  0.9039,  0.7760,  0.0361,  0.1706
-0.0887,  0.2104,  0.9808,  0.5478, -0.3314,  0.4128
-0.8220, -0.6302,  0.0537, -0.1658,  0.6013,  0.4306
-0.4123, -0.2880,  0.9074, -0.0461, -0.4435,  0.5144
 0.0060,  0.2867, -0.7775,  0.5161,  0.7039,  0.3599
-0.7968, -0.5484,  0.9426, -0.4308,  0.8148,  0.2979
 0.7811,  0.8450, -0.6877,  0.7594,  0.2640,  0.2362
-0.6802, -0.1113, -0.8325, -0.6694, -0.6056,  0.6544
 0.3821,  0.1476,  0.7466, -0.5107,  0.2592,  0.1648
 0.7265,  0.9683, -0.9803, -0.4943, -0.5523,  0.2454
-0.9049, -0.9797, -0.0196, -0.9090, -0.4433,  0.6447
-0.4607,  0.1811, -0.2389,  0.4050, -0.0078,  0.5229
 0.2664, -0.2932, -0.4259, -0.7336,  0.8742,  0.1834
-0.4507,  0.1029, -0.6294, -0.1158, -0.6294,  0.6081
 0.8948, -0.0124,  0.9278,  0.2899, -0.0314,  0.1534
-0.1323, -0.8813, -0.0146, -0.0697,  0.6135,  0.2386