Decision Tree Regression From Scratch Using C#, No Pointers, No Recursion

Decision tree regression is a machine learning technique that incorporates a set of if-then rules in a tree data structure to predict a single numeric value. For example, a decision tree regression model prediction might be, “If employee age is greater than 39.0 and age is less than or equal to 42.5 and years-experience is less than or equal to 8.0 and height is greater than 64.0 then bank account balance is $754.14.”

This blog post presents decision tree regression, implemented from scratch with C#, using List storage for tree nodes (no pointers/references), a LIFO stack to build the tree (no recursion), and weighted variance minimization for the node split function. Each node saves the indices of its associated training data rows.

The demo data looks like:

-0.1660,  0.4406, -0.9998, -0.3953, -0.7065,  0.4840
 0.0776, -0.1616,  0.3704, -0.5911,  0.7562,  0.1568
-0.9452,  0.3409, -0.1654,  0.1174, -0.7192,  0.8054
 0.9365, -0.3732,  0.3846,  0.7528,  0.7892,  0.1345
. . .

The first five values on each line are the x predictors. The last value on each line is the target y variable to predict. The data is synthetic and was generated by a 5-10-1 neural network with random weights and biases. Therefore, the data has an underlying structure that can be predicted. There are 200 training items and 40 test items.
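
The program that generated the synthetic data isn't shown here, but one way to generate such data looks something like the following minimal sketch. The tanh hidden activation, the logistic sigmoid output (so that y is between 0 and 1), and the [-1, +1] weight range are illustrative assumptions, not a description of the actual generator.

using System;

// hypothetical sketch: generate synthetic data using a 5-10-1
// neural network with random weights and biases -- the activation
// functions, weight ranges, and output scaling are assumptions
public class MakeSyntheticData
{
  public static void Main()
  {
    Random rnd = new Random(1);  // arbitrary seed
    int nIn = 5; int nHid = 10;

    double[][] ihWts = new double[nIn][];  // input-to-hidden
    for (int i = 0; i < nIn; ++i)
    {
      ihWts[i] = new double[nHid];
      for (int j = 0; j < nHid; ++j)
        ihWts[i][j] = 2.0 * rnd.NextDouble() - 1.0;  // [-1, +1]
    }
    double[] hBiases = new double[nHid];
    double[] hoWts = new double[nHid];  // hidden-to-output
    for (int j = 0; j < nHid; ++j)
    {
      hBiases[j] = 2.0 * rnd.NextDouble() - 1.0;
      hoWts[j] = 2.0 * rnd.NextDouble() - 1.0;
    }
    double oBias = 2.0 * rnd.NextDouble() - 1.0;

    for (int r = 0; r < 200; ++r)  // one data item per line
    {
      double[] x = new double[nIn];
      for (int i = 0; i < nIn; ++i)
        x[i] = 2.0 * rnd.NextDouble() - 1.0;  // predictors

      double z = oBias;
      for (int j = 0; j < nHid; ++j)
      {
        double h = hBiases[j];
        for (int i = 0; i < nIn; ++i)
          h += x[i] * ihWts[i][j];
        z += Math.Tanh(h) * hoWts[j];  // tanh hidden activation
      }
      double y = 1.0 / (1.0 + Math.Exp(-z));  // sigmoid: y in (0, 1)

      string line = "";
      for (int i = 0; i < nIn; ++i)
        line += x[i].ToString("F4") + ", ";
      line += y.ToString("F4");
      Console.WriteLine(line);
    }
  }
}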

The key parts of the output of the demo program (in slightly edited form to save space) are:

Begin decision tree regression

Loading synthetic train (200) and test (40) data
Done

First three train X:
 -0.1660  0.4406 -0.9998 -0.3953 -0.7065
  0.0776 -0.1616  0.3704 -0.5911  0.7562
 -0.9452  0.3409 -0.1654  0.1174 -0.7192

First three train y:
  0.4840
  0.1568
  0.8054

Setting maxDepth = 3
Setting minSamples = 2
Setting minLeaf = 18
Using default numSplitCols = -1

Creating and training tree
Done

Tree:
ID = 0  | col =  0 | -0.2102 | left =  1 | right =  2 | val = 0.3493 | F
ID = 1  | col =  4 |  0.1431 | left =  3 | right =  4 | val = 0.5345 | F
ID = 2  | col =  0 |  0.3915 | left =  5 | right =  6 | val = 0.2382 | F
ID = 3  | col =  0 | -0.6553 | left =  7 | right =  8 | val = 0.6358 | F
ID = 4  | col = -1 |  0.0000 | left =  9 | right = 10 | val = 0.4123 | T
ID = 5  | col =  4 | -0.2987 | left = 11 | right = 12 | val = 0.3032 | F
ID = 6  | col =  2 |  0.3777 | left = 13 | right = 14 | val = 0.1701 | F
ID = 7  | col = -1 |  0.0000 | left = 15 | right = 16 | val = 0.6952 | T
ID = 8  | col = -1 |  0.0000 | left = 17 | right = 18 | val = 0.5598 | T
ID = 11 | col = -1 |  0.0000 | left = 23 | right = 24 | val = 0.4101 | T
ID = 12 | col = -1 |  0.0000 | left = 25 | right = 26 | val = 0.2613 | T
ID = 13 | col = -1 |  0.0000 | left = 27 | right = 28 | val = 0.1882 | T
ID = 14 | col = -1 |  0.0000 | left = 29 | right = 30 | val = 0.1381 | T 

Evaluating model
Accuracy train (within 0.10) = 0.3750
Accuracy test (within 0.10) = 0.4750

MSE train = 0.0048
MSE test = 0.0054

Predicting for trainX[0] =
  -0.1660   0.4406  -0.9998  -0.3953  -0.7065
Predicted y = 0.4101

IF
column 0  >   -0.2102 AND
column 0  <=   0.3915 AND
column 4  <=  -0.2987 AND
THEN
predicted = 0.4101

End demo

If you compare this output with the visual representation below, most of the ideas of decision tree regression should become clear:


[Figure: diagram of the demo decision tree. Click to enlarge.]

The demo decision tree has maxDepth = 3. This creates a tree with 15 nodes (some of which could be empty). In general, if a tree has maxDepth = n, then it has at most 2^(n+1) - 1 nodes. These tree nodes could be stored as C# program-defined Node objects connected via pointers/references. Instead, the architecture used by the demo decision tree is to create a List collection with size/count = 15, where each item in the List is a Node object.

With this design, if the root node is at index [0], then the left child is at [1] and the right child is at [2]. In general, if a node is at index [i], then the left child is at index [2*i+1] and the right child is at index [2*i+2].
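
In code form, the index arithmetic is trivial (the helper names below are illustrative, not part of the demo program):

// index arithmetic for a complete binary tree stored in a List
static int LeftChild(int i) { return (2 * i) + 1; }
static int RightChild(int i) { return (2 * i) + 2; }
static int Parent(int i) { return (i - 1) / 2; }  // integer division

// example: node [5] has left child [11] and right child [12],
// which matches nodes 5, 11, 12 in the demo tree output above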

The tree splitting part of constructing a decision tree regression system is perhaps best understood by looking at a concrete example. The root node of the demo decision tree is associated with all 200 rows of the training data. The splitting process selects some of the 200 rows to be assigned to the left child node, and the remaining rows to be assigned to the right child node.

For simplicity, suppose that instead of the demo training data with 200 rows and 5 columns of predictors, a tree node is associated with just 5 rows of data, and the data has just 3 columns of predictors:

     X                    y
[0] 0.99  0.22  0.77     3.0
[1] 0.55  0.44  0.00     9.0
[2] 0.88  0.66  0.88     7.0
[3] 0.11  0.33  0.22     1.0
[4] 0.55  0.88  0.33     5.0
     [0]   [1]   [2]

The split algorithm selects some of the five rows to go to the left child and the remaining rows to go to the right child. The idea is to select the two sets of rows so that, within each set, the associated y values are close together. The demo implementation does this by finding the split that minimizes the weighted variance of the target y values.

The splitting algorithm examines every x value in every column. For each candidate split value (also called the threshold), the weighted variance of the y values of the resulting split is computed. For example, for threshold x = 0.33 in column [1], the rows where the values in column [1] are less than or equal to 0.33 are rows [3] and [0], with associated y values (1.0, 3.0). The other rows are [1], [2], [4], with associated y values (9.0, 7.0, 5.0).

The variance of a set of y values is the average of the squared differences between the values and their mean (average).

For this proposed split, the variance of the left y values (1.0, 3.0) is computed as 1.00. The mean of (1.0, 3.0) = (1.0 + 3.0) / 2 = 2.0, and so the variance is ((1.0 - 2.0)^2 + (3.0 - 2.0)^2) / 2 = 1.00.

The mean of the right y values (9.0, 7.0, 5.0) = (9.0 + 7.0 + 5.0) / 3 = 7.0, and the variance is ((9.0 - 7.0)^2 + (7.0 - 7.0)^2 + (5.0 - 7.0)^2) / 3 = 2.67.

Because the proposed split sends different numbers of rows to the two children (two rows to the left child, three rows to the right child), the weighted variance of the proposed split is used: ((2 * 1.00) + (3 * 2.67)) / 5 = 2.00.

This process is repeated for every x value in every column. The x value and its column that generate the smallest weighted variance define the split.
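
To verify the arithmetic, here is a short stand-alone snippet that reproduces the worked example. The Variance() helper is illustrative and isn't part of the demo program:

using System;

public class WeightedVarianceDemo
{
  static double Variance(double[] vals)
  {
    // population variance: average squared deviation from the mean
    double mean = 0.0;
    foreach (double v in vals) mean += v;
    mean /= vals.Length;
    double sum = 0.0;
    foreach (double v in vals)
      sum += (v - mean) * (v - mean);
    return sum / vals.Length;
  }

  public static void Main()
  {
    double[] leftY = new double[] { 1.0, 3.0 };        // rows [3], [0]
    double[] rightY = new double[] { 9.0, 7.0, 5.0 };  // rows [1], [2], [4]
    double leftVar = Variance(leftY);    // 1.00
    double rightVar = Variance(rightY);  // 2.67
    int n = leftY.Length + rightY.Length;
    double weighted = ((leftY.Length * leftVar) +
      (rightY.Length * rightVar)) / n;
    Console.WriteLine(weighted.ToString("F2"));  // 2.00
  }
}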

The demo does not use recursion to build the tree. Instead, the demo uses a stack data structure. In high-level pseudo-code:

create empty stack
push (root, depth = 0)
while stack not empty
  pop (currNode, currDepth)
  if at maxDepth or not enough data then
    node is a leaf, predicted = avg of associated y values
    continue
  compute split value and split column using the
    associated rows stored in currNode
  if unable to split then
    node is a leaf, predicted = avg of associated y values
    continue
  compute left-rows, right-rows using split value and column
  create and push (leftNode, currDepth+1)
  create and push (rightNode, currDepth+1)
end-while

The build-tree algorithm is short but quite tricky. Each stack entry holds a tree node (which stores its associated rows of training data) and the node's depth in the tree. The depth is maintained so that the algorithm knows when maxDepth has been reached.
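
Stripped of the splitting logic, the core pattern is shown in the following minimal sketch, which just visits the 15 node slots of a maxDepth = 3 tree:

using System;
using System.Collections.Generic;

// simplified sketch of the no-recursion build pattern: visit all
// 15 node slots of a maxDepth = 3 tree using an explicit LIFO
// stack of (node ID, depth) pairs -- splitting logic omitted
public class StackPatternDemo
{
  public static void Main()
  {
    int maxDepth = 3;
    Stack<int[]> stack = new Stack<int[]>();  // [id, depth]
    stack.Push(new int[] { 0, 0 });  // root node at depth 0
    while (stack.Count > 0)
    {
      int[] info = stack.Pop();
      int id = info[0]; int depth = info[1];
      Console.WriteLine("visiting node " + id +
        " at depth " + depth);
      if (depth == maxDepth) continue;  // leaf level: no children
      stack.Push(new int[] { (2 * id) + 2, depth + 1 });  // right
      stack.Push(new int[] { (2 * id) + 1, depth + 1 });  // left
    }
  }
}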

Decision trees are rarely used by themselves because they usually overfit the training data and therefore predict poorly on new, previously unseen data. Instead, multiple decision trees are usually combined using an ensemble technique such as bagging ("bootstrap aggregation"), random forest regression, adaptive boosting regression, or gradient boosting regression.



Decision trees are not plants.

“Utopia Zukunftsromane” (and several spinoffs) is a German series of over 600 inexpensive science fiction novels that were published weekly between 1953 and 1968. Several of the novels featured bad plants.

Left: Issue #103 (October 14, 1957). Center: Issue #206 (January 12, 1960). Right: Issue #337 (July 10, 1962).

There is some contradictory information about the publication date of Issue #206 in online databases, so the date listed here might not be correct, but it's certainly close.

The thing on the cover of Issue #337 might not be a plant, but it looks like one to me. I was unable to find an online version of the text of the novel, so I wasn't able to find out.


Demo code. Replace "lt" (less than), "gt" (greater than), "lte" (less than or equal), "gte" (greater than or equal) with the corresponding operator symbols (my blog editor chokes on symbols).

using System;
using System.IO;
using System.Collections.Generic;

// full (not-sparse) List storage with indexes (no pointers)
// stack-based construction (no recursion)
// the Nodes store associated row idxs 

namespace DecisionTreeRegression
{
  internal class DecisionTreeRegressionProgram
  {
    static void Main(string[] args)
    {
      Console.WriteLine("\nBegin decision tree regression ");

      // 1. load data
      Console.WriteLine("\nLoading synthetic train" +
        " (200) and test (40) data");
      string trainFile = "..\\..\\..\\Data\\" +
        "synthetic_train_200.txt";
      int[] colsX = new int[] { 0, 1, 2, 3, 4 };
      double[][] trainX =
        MatLoad(trainFile, colsX, ',', "#");
      double[] trainY =
        MatToVec(MatLoad(trainFile,
        new int[] { 5 }, ',', "#"));

      string testFile = "..\\..\\..\\Data\\" +
        "synthetic_test_40.txt";
      double[][] testX =
        MatLoad(testFile, colsX, ',', "#");
      double[] testY =
        MatToVec(MatLoad(testFile,
        new int[] { 5 }, ',', "#"));
      Console.WriteLine("Done ");

      Console.WriteLine("\nFirst three train X: ");
      for (int i = 0; i "lt" 3; ++i)
        VecShow(trainX[i], 4, 8);

      Console.WriteLine("\nFirst three train y: ");
      for (int i = 0; i "lt" 3; ++i)
        Console.WriteLine(trainY[i].ToString("F4").
          PadLeft(8));

      // 2. create and train/build tree
      int maxDepth = 3;
      int minSamples = 2;
      int minLeaf = 18;
      int numSplitCols = -1;  // all columns

      // alternate hyperparameters:
      // maxDepth = 3; minSamples = 18;
      // minLeaf = 1; numSplitCols = -1 (all columns)

      Console.WriteLine("\nSetting maxDepth = " +
        maxDepth);
      Console.WriteLine("Setting minSamples = " +
        minSamples);
      Console.WriteLine("Setting minLeaf = " +
        minLeaf);
      Console.WriteLine("Using default numSplitCols = -1 ");

      Console.WriteLine("\nCreating and training tree ");
      DecisionTreeRegressor dtr =
        new DecisionTreeRegressor(maxDepth, minSamples,
        minLeaf, numSplitCols, seed: 0);
      dtr.Train(trainX, trainY);
      Console.WriteLine("Done ");

      Console.WriteLine("\nTree: ");
      dtr.Display(showRows: false);
     
      // 3. evaluate model
      Console.WriteLine("\nEvaluating model ");
      double accTrain = dtr.Accuracy(trainX, trainY, 0.10);
      Console.WriteLine("Accuracy train (within 0.10) = " +
        accTrain.ToString("F4"));
      double accTest = dtr.Accuracy(testX, testY, 0.10);
      Console.WriteLine("Accuracy test (within 0.10) = " +
        accTest.ToString("F4"));

      double mseTrain = dtr.MSE(trainX, trainY);
      Console.WriteLine("\nMSE train = " +
        mseTrain.ToString("F4"));
      double mseTest = dtr.MSE(testX, testY);
      Console.WriteLine("MSE test = " +
        mseTest.ToString("F4"));

      // 4. use model
      Console.WriteLine("\nPredicting for trainX[0] = ");
      double[] x = trainX[0];
      VecShow(x, 4, 9);
      double predY = dtr.Predict(x);
      Console.WriteLine("Predicted y = " +
        predY.ToString("F4"));

      dtr.Explain(x);

      Console.WriteLine("\nEnd demo ");
      Console.ReadLine();

    } // Main()

    // ------------------------------------------------------
    // helpers for Main():
    //   MatLoad(), MatToVec(), VecShow().
    // ------------------------------------------------------

    static double[][] MatLoad(string fn, int[] usecols,
      char sep, string comment)
    {
      List"lt"double[]"gt" result =
        new List"lt"double[]"gt"();
      string line = "";
      FileStream ifs = new FileStream(fn, FileMode.Open);
      StreamReader sr = new StreamReader(ifs);
      while ((line = sr.ReadLine()) != null)
      {
        if (line.StartsWith(comment) == true)
          continue;
        string[] tokens = line.Split(sep);
        List"lt"double"gt" lst = new List"lt"double"gt"();
        for (int j = 0; j "lt" usecols.Length; ++j)
          lst.Add(double.Parse(tokens[usecols[j]]));
        double[] row = lst.ToArray();
        result.Add(row);
      }
      sr.Close(); ifs.Close();
      return result.ToArray();
    }

    static double[] MatToVec(double[][] mat)
    {
      int nRows = mat.Length;
      int nCols = mat[0].Length;
      double[] result = new double[nRows * nCols];
      int k = 0;
      for (int i = 0; i "lt" nRows; ++i)
        for (int j = 0; j "lt" nCols; ++j)
          result[k++] = mat[i][j];
      return result;
    }

    static void VecShow(double[] vec, int dec, int wid)
    {
      for (int i = 0; i "lt" vec.Length; ++i)
        Console.Write(vec[i].ToString("F" + dec).
          PadLeft(wid));
      Console.WriteLine("");
    }

  } // class DecisionTreeRegressionProgram

  // ========================================================

  public class DecisionTreeRegressor
  {
    public int maxDepth;
    public int minSamples;  // aka min_samples_split
    public int minLeaf;  // min number of values in a leaf
    public int numSplitCols; // mostly for random forest
    public List"lt"Node"gt" tree = new List"lt"Node"gt"();
    public Random rnd;  // order in which cols are searched

    public double[][] trainX;  // store data by ref
    public double[] trainY;

    // ------------------------------------------------------

    public class Node
    {
      public int id;
      public int colIdx;  // aka featureIdx
      public double thresh;
      public int left;  // index into List
      public int right;
      public double value;
      public bool isLeaf;
      public List"lt"int"gt" rows;  // assoc rows in trainX

      public Node(int id, int colIdx, double thresh,
        int left, int right, double value, bool isLeaf,
        List"lt"int"gt" rows)
      {
        this.id = id;
        this.colIdx = colIdx;
        this.thresh = thresh;
        this.left = left;
        this.right = right;
        this.value = value;
        this.isLeaf = isLeaf;

        this.rows = rows;
      }
    } // class Node

    // --------------------------------------------

    public class StackInfo  // used to build tree
    {
      // tree node + associated rows, not explicit data
      public Node node; // node holds associated rows
      public int depth;

      public StackInfo(Node n, int d)
      {
        this.node = n;
        this.depth = d;
      }
    }

    // --------------------------------------------

    public DecisionTreeRegressor(int maxDepth = 2,
      int minSamples = 2, int minLeaf = 1,
      int numSplitCols = -1, int seed = 0)
    {
      // if maxDepth = 0, tree has just a root node
      // if maxDepth = 1, at most 3 nodes (root, l, r)
      // if maxDepth = n, at most 2^(n+1) - 1 nodes
      this.maxDepth = maxDepth;
      this.minSamples = minSamples;
      this.minLeaf = minLeaf;
      this.numSplitCols = numSplitCols;  // for ran. forest

      // create full tree List with dummy nodes
      int numNodes = (int)Math.Pow(2, (maxDepth + 1)) - 1;
      for (int i = 0; i "lt" numNodes; ++i)
      {
        //this.tree.Add(null); // OK, but let's avoid ptrs
        Node n = new Node(i, -1, 0.0, -1, -1, 0.0,
          false, new List"lt"int"gt"());
        this.tree.Add(n); // dummy node
      }
      this.rnd = new Random(seed);
    }

    // ------------------------------------------------------
    // public: Train(), Predict(), Explain(), Accuracy(),
    //   MSE(), Display().
    // helpers: MakeTree(), BestSplit(), TreeTargetMean(),
    //   TreeTargetVariance().
    // ------------------------------------------------------

    public void Train(double[][] trainX, double[] trainY)
    {
      this.trainX = trainX;
      this.trainY = trainY;
      this.MakeTree();
    }

    // ------------------------------------------------------

    public double Predict(double[] x)
    {
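      // walk the tree from the root node, going left or right
      // at each interior node, until a leaf node is reached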
      int p = 0;
      Node currNode = this.tree[p];
      while (currNode.isLeaf == false && 
        p "lt" this.tree.Count)
      {
        if (x[currNode.colIdx] "lte" currNode.thresh)
          p = currNode.left;
        else
          p = currNode.right;
        currNode = this.tree[p];
      }
      return this.tree[p].value;

    }

    // ------------------------------------------------------

    public void Explain(double[] x)
    {
      int p = 0;
      Console.WriteLine("\nIF ");
      Node currNode = this.tree[p];
      while (currNode.isLeaf == false)
      {
        Console.Write("column ");
        Console.Write(currNode.colIdx + " ");
        if (x[currNode.colIdx] "lte" currNode.thresh)
        {
          Console.Write(" "lte" ");
          Console.WriteLine(currNode.thresh.
            ToString("F4").PadLeft(8) + " AND ");
          p = currNode.left;
          currNode = this.tree[p];
        }
        else
        {
          Console.Write(" "gt"  ");
          Console.WriteLine(currNode.thresh.
            ToString("F4").PadLeft(8) + " AND ");
          p = currNode.right;
          currNode = this.tree[p];
        }
      }
      Console.WriteLine("THEN \npredicted = " +
        currNode.value.ToString("F4"));
    }

    // ------------------------------------------------------

    public double Accuracy(double[][] dataX, double[] dataY,
      double pctClose)
    {
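      // a prediction is counted correct if it is within
      // pctClose (e.g., 0.10 = 10%) of the true target value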
      int numCorrect = 0; int numWrong = 0;
      for (int i = 0; i "lt" dataX.Length; ++i)
      {
        double actualY = dataY[i];
        double predY = this.Predict(dataX[i]);
        if (Math.Abs(predY - actualY) "lt"
          Math.Abs(pctClose * actualY))
          ++numCorrect;
        else
          ++numWrong;
      }
      return (numCorrect * 1.0) / (numWrong + numCorrect);
    }

    // ------------------------------------------------------

    public double MSE(double[][] dataX,
      double[] dataY)
    {
      // standard machine learning MSE
      int n = dataX.Length;
      double sum = 0.0;
      for (int i = 0; i "lt" n; ++i)
      {
        double actualY = dataY[i];
        double predY = this.Predict(dataX[i]);
        sum += (actualY - predY) * (actualY - predY);
      }
      return sum / n;
    }

    // ------------------------------------------------------

    public void Display(bool showRows)
    {
      for (int i = 0; i "lt" this.tree.Count; ++i)
      {
        Node n = this.tree[i];
        // check for empty nodes
        if (n == null) continue;
        if (n.colIdx == -1 && n.isLeaf == false) continue;

        string s1 = "ID " +
          n.id.ToString().PadRight(3) + " | ";
        string s2 = "idx " +
          n.colIdx.ToString().PadLeft(3) + " | ";
        string s3 = "v " +
          n.thresh.ToString("F4").PadLeft(8) + " | ";
        string s4 = "L " +
          n.left.ToString().PadLeft(3) + " | ";
        string s5 = "R " +
          n.right.ToString().PadLeft(3) + " | ";
        string s6 = "v " +
          n.value.ToString("F4").PadLeft(8) + " | ";
        string s7 = "leaf " + (n.isLeaf == true ? "T" : "F");

        string s8;
        if (showRows == true)
        {
          s8 = "\nRows:";
          for (int j = 0; j "lt" n.rows.Count; ++j)
          {
            if (j "gt" 0 && j % 20 == 0)
              s8 += "\n";
            s8 += n.rows[j].ToString().PadLeft(4) + " ";
          }
        }
        else
        {
          s8 = "";
        }
        Console.WriteLine(s1 + s2 + s3 + s4 +
          s5 + s6 + s7 + s8);
        if (showRows == true) Console.WriteLine("\n");
      }
    } // Display()

    // ------------------------------------------------------

    // helpers: MakeTree(), BestSplit(),
    // TreeTargetMean(), TreeTargetVariance()

    private void MakeTree()
    {
      // no recursion, no pointers, List storage
      if (this.numSplitCols == -1) // use all cols
        this.numSplitCols = this.trainX[0].Length;

      List"lt"int"gt" allRows = new List"lt"int"gt"();
      for (int i = 0; i "lt" this.trainX.Length; ++i)
        allRows.Add(i);
      double grandMean = this.TreeTargetMean(allRows);

      Node root = new Node(0, -1, 0.0, 1, 2,
        grandMean, false, allRows);

      Stack"lt"StackInfo"gt" stack = 
        new Stack"lt"StackInfo"gt"();
      stack.Push(new StackInfo(root, 0));
      while (stack.Count "gt" 0)
      {
        StackInfo info = stack.Pop();
        Node currNode = info.node;
        this.tree[currNode.id] = currNode;

        List"lt"int"gt" currRows = info.node.rows;
        int currDepth = info.depth;

        if (currDepth == this.maxDepth ||
          currRows.Count "lt" this.minSamples)
        {
          currNode.value = this.TreeTargetMean(currRows);
          currNode.isLeaf = true;
          continue;
        }

        double[] splitInfo = this.BestSplit(currRows);
        int colIdx = (int)splitInfo[0];
        double thresh = splitInfo[1];

        if (colIdx == -1)  // unable to split
        {
          currNode.value = this.TreeTargetMean(currRows);
          currNode.isLeaf = true;
          continue;
        }

        // got good split so at internal, non-leaf node
        currNode.colIdx = colIdx;
        currNode.thresh = thresh;

        // make the data splits for children
        // find left rows and right rows
        List"lt"int"gt" leftIdxs = new List"lt"int"gt"();
        List"lt"int"gt" rightIdxs = new List"lt"int"gt"();

        for (int k = 0; k "lt" currRows.Count; ++k)
        {
          int r = currRows[k];
          if (this.trainX[r][colIdx] "lte" thresh)
            leftIdxs.Add(r);
          else
            rightIdxs.Add(r);
        }

        int leftID = currNode.id * 2 + 1;
        Node currNodeLeft = new Node(leftID, -1, 0.0,
          2 * leftID + 1, 2 * leftID + 2,
          this.TreeTargetMean(leftIdxs), false, leftIdxs);
        stack.Push(new StackInfo(currNodeLeft,
          currDepth + 1));

        int rightID = currNode.id * 2 + 2;
        Node currNodeRight = new Node(rightID, -1, 0.0,
          2 * rightID + 1, 2 * rightID + 2,
          this.TreeTargetMean(rightIdxs), false, rightIdxs);
        stack.Push(new StackInfo(currNodeRight,
          currDepth + 1));

      } // while

      return;

    } // MakeTree()

    // ------------------------------------------------------

    private double[] BestSplit(List"lt"int"gt" rows)
    {
      // implicit params: numSplitCols, minLeaf, rnd
      // result[0] = best col idx (as double)
      // result[1] = best split value
      rows.Sort();

      int bestColIdx = -1;  // indicates bad split
      double bestThresh = 0.0;
      double bestVar = double.MaxValue;  // smaller better

      int nRows = rows.Count;  // number of rows at this node
      int nCols = this.trainX[0].Length;

      if (nRows == 0)
      {
        throw new Exception("empty data in BestSplit()");
      }

      // process cols in scrambled order
      int[] colIndices = new int[nCols];
      for (int k = 0; k "lt" nCols; ++k)
        colIndices[k] = k;
      // shuffle, inline Fisher-Yates
      int n = colIndices.Length;
      for (int i = 0; i "lt" n; ++i)
      {
        int ri = rnd.Next(i, n);  // be careful
        int tmp = colIndices[i];
        colIndices[i] = colIndices[ri];
        colIndices[ri] = tmp;
      }

      // numSplitCols is usually all columns (-1)
      for (int j = 0; j "lt" this.numSplitCols; ++j)
      {
        int colIdx = colIndices[j];
        HashSet"lt"double"gt" examineds = 
          new HashSet"lt"double"gt"();

        for (int i = 0; i "lt" nRows; ++i) // each row
        {
          // if curr thresh been seen, skip it
          double thresh = this.trainX[rows[i]][colIdx];
          if (examineds.Contains(thresh)) continue;
          examineds.Add(thresh);

          // get row idxs where x is lte, gt thresh
          List"lt"int"gt" leftIdxs = new List"lt"int"gt"();
          List"lt"int"gt" rightIdxs = new List"lt"int"gt"();
          for (int k = 0; k "lt" nRows; ++k)
          {
            if (this.trainX[rows[k]][colIdx] "lte" thresh)
              leftIdxs.Add(rows[k]);
            else
              rightIdxs.Add(rows[k]);
          }

          // Check if proposed split has too few values
          if (leftIdxs.Count "lt" this.minLeaf ||
            rightIdxs.Count "lt" this.minLeaf)
            continue;  // to next row

          double leftVar = 
            this.TreeTargetVariance(leftIdxs);
          double rightVar = 
            this.TreeTargetVariance(rightIdxs);
          double weightedVar = (leftIdxs.Count * leftVar +
            rightIdxs.Count * rightVar) / nRows;

          if (weightedVar "lt" bestVar)
          {
            bestColIdx = colIdx;
            bestThresh = thresh;
            bestVar = weightedVar;
          }

        } // each row
      } // j each col

      double[] result = new double[2];  // out params ugly
      result[0] = 1.0 * bestColIdx;
      result[1] = bestThresh;
      return result;

    } // BestSplit()

    // ------------------------------------------------------

    private double TreeTargetMean(List"lt"int"gt" rows)
    {
      // mean of rows items in trainY: for node prediction
      double sum = 0.0;
      for (int i = 0; i "lt" rows.Count; ++i)
      {
        int r = rows[i];
        sum += this.trainY[r];
      }
      return sum / rows.Count;
    }

    // ------------------------------------------------------

    private double TreeTargetVariance(List"lt"int"gt" rows)
    {
      double mean = this.TreeTargetMean(rows);
      double sum = 0.0;
      for (int i = 0; i "lt" rows.Count; ++i)
      {
        int r = rows[i];
        sum += (this.trainY[r] - mean) *
          (this.trainY[r] - mean);
      }
      return sum / rows.Count;
    }

    // ------------------------------------------------------

  } // class DecisionTreeRegressor

} // ns

Training data:

# synthetic_train_200.txt
#
-0.1660,  0.4406, -0.9998, -0.3953, -0.7065,  0.4840
 0.0776, -0.1616,  0.3704, -0.5911,  0.7562,  0.1568
-0.9452,  0.3409, -0.1654,  0.1174, -0.7192,  0.8054
 0.9365, -0.3732,  0.3846,  0.7528,  0.7892,  0.1345
-0.8299, -0.9219, -0.6603,  0.7563, -0.8033,  0.7955
 0.0663,  0.3838, -0.3690,  0.3730,  0.6693,  0.3206
-0.9634,  0.5003,  0.9777,  0.4963, -0.4391,  0.7377
-0.1042,  0.8172, -0.4128, -0.4244, -0.7399,  0.4801
-0.9613,  0.3577, -0.5767, -0.4689, -0.0169,  0.6861
-0.7065,  0.1786,  0.3995, -0.7953, -0.1719,  0.5569
 0.3888, -0.1716, -0.9001,  0.0718,  0.3276,  0.2500
 0.1731,  0.8068, -0.7251, -0.7214,  0.6148,  0.3297
-0.2046, -0.6693,  0.8550, -0.3045,  0.5016,  0.2129
 0.2473,  0.5019, -0.3022, -0.4601,  0.7918,  0.2613
-0.1438,  0.9297,  0.3269,  0.2434, -0.7705,  0.5171
 0.1568, -0.1837, -0.5259,  0.8068,  0.1474,  0.3307
-0.9943,  0.2343, -0.3467,  0.0541,  0.7719,  0.5581
 0.2467, -0.9684,  0.8589,  0.3818,  0.9946,  0.1092
-0.6553, -0.7257,  0.8652,  0.3936, -0.8680,  0.7018
 0.8460,  0.4230, -0.7515, -0.9602, -0.9476,  0.1996
-0.9434, -0.5076,  0.7201,  0.0777,  0.1056,  0.5664
 0.9392,  0.1221, -0.9627,  0.6013, -0.5341,  0.1533
 0.6142, -0.2243,  0.7271,  0.4942,  0.1125,  0.1661
 0.4260,  0.1194, -0.9749, -0.8561,  0.9346,  0.2230
 0.1362, -0.5934, -0.4953,  0.4877, -0.6091,  0.3810
 0.6937, -0.5203, -0.0125,  0.2399,  0.6580,  0.1460
-0.6864, -0.9628, -0.8600, -0.0273,  0.2127,  0.5387
 0.9772,  0.1595, -0.2397,  0.1019,  0.4907,  0.1611
 0.3385, -0.4702, -0.8673, -0.2598,  0.2594,  0.2270
-0.8669, -0.4794,  0.6095, -0.6131,  0.2789,  0.4700
 0.0493,  0.8496, -0.4734, -0.8681,  0.4701,  0.3516
 0.8639, -0.9721, -0.5313,  0.2336,  0.8980,  0.1412
 0.9004,  0.1133,  0.8312,  0.2831, -0.2200,  0.1782
 0.0991,  0.8524,  0.8375, -0.2102,  0.9265,  0.2150
-0.6521, -0.7473, -0.7298,  0.0113, -0.9570,  0.7422
 0.6190, -0.3105,  0.8802,  0.1640,  0.7577,  0.1056
 0.6895,  0.8108, -0.0802,  0.0927,  0.5972,  0.2214
 0.1982, -0.9689,  0.1870, -0.1326,  0.6147,  0.1310
-0.3695,  0.7858,  0.1557, -0.6320,  0.5759,  0.3773
-0.1596,  0.3581,  0.8372, -0.9992,  0.9535,  0.2071
-0.2468,  0.9476,  0.2094,  0.6577,  0.1494,  0.4132
 0.1737,  0.5000,  0.7166,  0.5102,  0.3961,  0.2611
 0.7290, -0.3546,  0.3416, -0.0983, -0.2358,  0.1332
-0.3652,  0.2438, -0.1395,  0.9476,  0.3556,  0.4170
-0.6029, -0.1466, -0.3133,  0.5953,  0.7600,  0.4334
-0.4596, -0.4953,  0.7098,  0.0554,  0.6043,  0.2775
 0.1450,  0.4663,  0.0380,  0.5418,  0.1377,  0.2931
-0.8636, -0.2442, -0.8407,  0.9656, -0.6368,  0.7429
 0.6237,  0.7499,  0.3768,  0.1390, -0.6781,  0.2185
-0.5499,  0.1850, -0.3755,  0.8326,  0.8193,  0.4399
-0.4858, -0.7782, -0.6141, -0.0008,  0.4572,  0.4197
 0.7033, -0.1683,  0.2334, -0.5327, -0.7961,  0.1776
 0.0317, -0.0457, -0.6947,  0.2436,  0.0880,  0.3345
 0.5031, -0.5559,  0.0387,  0.5706, -0.9553,  0.3107
-0.3513,  0.7458,  0.6894,  0.0769,  0.7332,  0.3170
 0.2205,  0.5992, -0.9309,  0.5405,  0.4635,  0.3532
-0.4806, -0.4859,  0.2646, -0.3094,  0.5932,  0.3202
 0.9809, -0.3995, -0.7140,  0.8026,  0.0831,  0.1600
 0.9495,  0.2732,  0.9878,  0.0921,  0.0529,  0.1289
-0.9476, -0.6792,  0.4913, -0.9392, -0.2669,  0.5966
 0.7247,  0.3854,  0.3819, -0.6227, -0.1162,  0.1550
-0.5922, -0.5045, -0.4757,  0.5003, -0.0860,  0.5863
-0.8861,  0.0170, -0.5761,  0.5972, -0.4053,  0.7301
 0.6877, -0.2380,  0.4997,  0.0223,  0.0819,  0.1404
 0.9189,  0.6079, -0.9354,  0.4188, -0.0700,  0.1907
-0.1428, -0.7820,  0.2676,  0.6059,  0.3936,  0.2790
 0.5324, -0.3151,  0.6917, -0.1425,  0.6480,  0.1071
-0.8432, -0.9633, -0.8666, -0.0828, -0.7733,  0.7784
-0.9444,  0.5097, -0.2103,  0.4939, -0.0952,  0.6787
-0.0520,  0.6063, -0.1952,  0.8094, -0.9259,  0.4836
 0.5477, -0.7487,  0.2370, -0.9793,  0.0773,  0.1241
 0.2450,  0.8116,  0.9799,  0.4222,  0.4636,  0.2355
 0.8186, -0.1983, -0.5003, -0.6531, -0.7611,  0.1511
-0.4714,  0.6382, -0.3788,  0.9648, -0.4667,  0.5950
 0.0673, -0.3711,  0.8215, -0.2669, -0.1328,  0.2677
-0.9381,  0.4338,  0.7820, -0.9454,  0.0441,  0.5518
-0.3480,  0.7190,  0.1170,  0.3805, -0.0943,  0.4724
-0.9813,  0.1535, -0.3771,  0.0345,  0.8328,  0.5438
-0.1471, -0.5052, -0.2574,  0.8637,  0.8737,  0.3042
-0.5454, -0.3712, -0.6505,  0.2142, -0.1728,  0.5783
 0.6327, -0.6297,  0.4038, -0.5193,  0.1484,  0.1153
-0.5424,  0.3282, -0.0055,  0.0380, -0.6506,  0.6613
 0.1414,  0.9935,  0.6337,  0.1887,  0.9520,  0.2540
-0.9351, -0.8128, -0.8693, -0.0965, -0.2491,  0.7353
 0.9507, -0.6640,  0.9456,  0.5349,  0.6485,  0.1059
-0.0462, -0.9737, -0.2940, -0.0159,  0.4602,  0.2606
-0.0627, -0.0852, -0.7247, -0.9782,  0.5166,  0.2977
 0.0478,  0.5098, -0.0723, -0.7504, -0.3750,  0.3335
 0.0090,  0.3477,  0.5403, -0.7393, -0.9542,  0.4415
-0.9748,  0.3449,  0.3736, -0.1015,  0.8296,  0.4358
 0.2887, -0.9895, -0.0311,  0.7186,  0.6608,  0.2057
 0.1570, -0.4518,  0.1211,  0.3435, -0.2951,  0.3244
 0.7117, -0.6099,  0.4946, -0.4208,  0.5476,  0.1096
-0.2929, -0.5726,  0.5346, -0.3827,  0.4665,  0.2465
 0.4889, -0.5572, -0.5718, -0.6021, -0.7150,  0.2163
-0.7782,  0.3491,  0.5996, -0.8389, -0.5366,  0.6516
-0.5847,  0.8347,  0.4226,  0.1078, -0.3910,  0.6134
 0.8469,  0.4121, -0.0439, -0.7476,  0.9521,  0.1571
-0.6803, -0.5948, -0.1376, -0.1916, -0.7065,  0.7156
 0.2878,  0.5086, -0.5785,  0.2019,  0.4979,  0.2980
 0.2764,  0.1943, -0.4090,  0.4632,  0.8906,  0.2960
-0.8877,  0.6705, -0.6155, -0.2098, -0.3998,  0.7107
-0.8398,  0.8093, -0.2597,  0.0614, -0.0118,  0.6502
-0.8476,  0.0158, -0.4769, -0.2859, -0.7839,  0.7715
 0.5751, -0.7868,  0.9714, -0.6457,  0.1448,  0.1175
 0.4802, -0.7001,  0.1022, -0.5668,  0.5184,  0.1090
 0.4458, -0.6469,  0.7239, -0.9604,  0.7205,  0.0779
 0.5175,  0.4339,  0.9747, -0.4438, -0.9924,  0.2879
 0.8678,  0.7158,  0.4577,  0.0334,  0.4139,  0.1678
 0.5406,  0.5012,  0.2264, -0.1963,  0.3946,  0.2088
-0.9938,  0.5498,  0.7928, -0.5214, -0.7585,  0.7687
 0.7661,  0.0863, -0.4266, -0.7233, -0.4197,  0.1466
 0.2277, -0.3517, -0.0853, -0.1118,  0.6563,  0.1767
 0.3499, -0.5570, -0.0655, -0.3705,  0.2537,  0.1632
 0.7547, -0.1046,  0.5689, -0.0861,  0.3125,  0.1257
 0.8186,  0.2110,  0.5335,  0.0094, -0.0039,  0.1391
 0.6858, -0.8644,  0.1465,  0.8855,  0.0357,  0.1845
-0.4967,  0.4015,  0.0805,  0.8977,  0.2487,  0.4663
 0.6760, -0.9841,  0.9787, -0.8446, -0.3557,  0.1509
-0.1203, -0.4885,  0.6054, -0.0443, -0.7313,  0.4854
 0.8557,  0.7919, -0.0169,  0.7134, -0.1628,  0.2002
 0.0115, -0.6209,  0.9300, -0.4116, -0.7931,  0.4052
-0.7114, -0.9718,  0.4319,  0.1290,  0.5892,  0.3661
 0.3915,  0.5557, -0.1870,  0.2955, -0.6404,  0.2954
-0.3564, -0.6548, -0.1827, -0.5172, -0.1862,  0.4622
 0.2392, -0.4959,  0.5857, -0.1341, -0.2850,  0.2470
-0.3394,  0.3947, -0.4627,  0.6166, -0.4094,  0.5325
 0.7107,  0.7768, -0.6312,  0.1707,  0.7964,  0.2757
-0.1078,  0.8437, -0.4420,  0.2177,  0.3649,  0.4028
-0.3139,  0.5595, -0.6505, -0.3161, -0.7108,  0.5546
 0.4335,  0.3986,  0.3770, -0.4932,  0.3847,  0.1810
-0.2562, -0.2894, -0.8847,  0.2633,  0.4146,  0.4036
 0.2272,  0.2966, -0.6601, -0.7011,  0.0284,  0.2778
-0.0743, -0.1421, -0.0054, -0.6770, -0.3151,  0.3597
-0.4762,  0.6891,  0.6007, -0.1467,  0.2140,  0.4266
-0.4061,  0.7193,  0.3432,  0.2669, -0.7505,  0.6147
-0.0588,  0.9731,  0.8966,  0.2902, -0.6966,  0.4955
-0.0627, -0.1439,  0.1985,  0.6999,  0.5022,  0.3077
 0.1587,  0.8494, -0.8705,  0.9827, -0.8940,  0.4263
-0.7850,  0.2473, -0.9040, -0.4308, -0.8779,  0.7199
 0.4070,  0.3369, -0.2428, -0.6236,  0.4940,  0.2215
-0.0242,  0.0513, -0.9430,  0.2885, -0.2987,  0.3947
-0.5416, -0.1322, -0.2351, -0.0604,  0.9590,  0.3683
 0.1055,  0.7783, -0.2901, -0.5090,  0.8220,  0.2984
-0.9129,  0.9015,  0.1128, -0.2473,  0.9901,  0.4776
-0.9378,  0.1424, -0.6391,  0.2619,  0.9618,  0.5368
 0.7498, -0.0963,  0.4169,  0.5549, -0.0103,  0.1614
-0.2612, -0.7156,  0.4538, -0.0460, -0.1022,  0.3717
 0.7720,  0.0552, -0.1818, -0.4622, -0.8560,  0.1685
-0.4177,  0.0070,  0.9319, -0.7812,  0.3461,  0.3052
-0.0001,  0.5542, -0.7128, -0.8336, -0.2016,  0.3803
 0.5356, -0.4194, -0.5662, -0.9666, -0.2027,  0.1776
-0.2378,  0.3187, -0.8582, -0.6948, -0.9668,  0.5474
-0.1947, -0.3579,  0.1158,  0.9869,  0.6690,  0.2992
 0.3992,  0.8365, -0.9205, -0.8593, -0.0520,  0.3154
-0.0209,  0.0793,  0.7905, -0.1067,  0.7541,  0.1864
-0.4928, -0.4524, -0.3433,  0.0951, -0.5597,  0.6261
-0.8118,  0.7404, -0.5263, -0.2280,  0.1431,  0.6349
 0.0516, -0.8480,  0.7483,  0.9023,  0.6250,  0.1959
-0.3212,  0.1093,  0.9488, -0.3766,  0.3376,  0.2735
-0.3481,  0.5490, -0.3484,  0.7797,  0.5034,  0.4379
-0.5785, -0.9170, -0.3563, -0.9258,  0.3877,  0.4121
 0.3407, -0.1391,  0.5356,  0.0720, -0.9203,  0.3458
-0.3287, -0.8954,  0.2102,  0.0241,  0.2349,  0.3247
-0.1353,  0.6954, -0.0919, -0.9692,  0.7461,  0.3338
 0.9036, -0.8982, -0.5299, -0.8733, -0.1567,  0.1187
 0.7277, -0.8368, -0.0538, -0.7489,  0.5458,  0.0830
 0.9049,  0.8878,  0.2279,  0.9470, -0.3103,  0.2194
 0.7957, -0.1308, -0.5284,  0.8817,  0.3684,  0.2172
 0.4647, -0.4931,  0.2010,  0.6292, -0.8918,  0.3371
-0.7390,  0.6849,  0.2367,  0.0626, -0.5034,  0.7039
-0.1567, -0.8711,  0.7940, -0.5932,  0.6525,  0.1710
 0.7635, -0.0265,  0.1969,  0.0545,  0.2496,  0.1445
 0.7675,  0.1354, -0.7698, -0.5460,  0.1920,  0.1728
-0.5211, -0.7372, -0.6763,  0.6897,  0.2044,  0.5217
 0.1913,  0.1980,  0.2314, -0.8816,  0.5006,  0.1998
 0.8964,  0.0694, -0.6149,  0.5059, -0.9854,  0.1825
 0.1767,  0.7104,  0.2093,  0.6452,  0.7590,  0.2832
-0.3580, -0.7541,  0.4426, -0.1193, -0.7465,  0.5657
-0.5996,  0.5766, -0.9758, -0.3933, -0.9572,  0.6800
 0.9950,  0.1641, -0.4132,  0.8579,  0.0142,  0.2003
-0.4717, -0.3894, -0.2567, -0.5111,  0.1691,  0.4266
 0.3917, -0.8561,  0.9422,  0.5061,  0.6123,  0.1212
-0.0366, -0.1087,  0.3449, -0.1025,  0.4086,  0.2475
 0.3633,  0.3943,  0.2372, -0.6980,  0.5216,  0.1925
-0.5325, -0.6466, -0.2178, -0.3589,  0.6310,  0.3568
 0.2271,  0.5200, -0.1447, -0.8011, -0.7699,  0.3128
 0.6415,  0.1993,  0.3777, -0.0178, -0.8237,  0.2181
-0.5298, -0.0768, -0.6028, -0.9490,  0.4588,  0.4356
 0.6870, -0.1431,  0.7294,  0.3141,  0.1621,  0.1632
-0.5985,  0.0591,  0.7889, -0.3900,  0.7419,  0.2945
 0.3661,  0.7984, -0.8486,  0.7572, -0.6183,  0.3449
 0.6995,  0.3342, -0.3113, -0.6972,  0.2707,  0.1712
 0.2565,  0.9126,  0.1798, -0.6043, -0.1413,  0.2893
-0.3265,  0.9839, -0.2395,  0.9854,  0.0376,  0.4770
 0.2690, -0.1722,  0.9818,  0.8599, -0.7015,  0.3954
-0.2102, -0.0768,  0.1219,  0.5607, -0.0256,  0.3949
 0.8216, -0.9555,  0.6422, -0.6231,  0.3715,  0.0801
-0.2896,  0.9484, -0.7545, -0.6249,  0.7789,  0.4370
-0.9985, -0.5448, -0.7092, -0.5931,  0.7926,  0.5402

Test data:

# synthetic_test_40.txt
#
 0.7462,  0.4006, -0.0590,  0.6543, -0.0083,  0.1935
 0.8495, -0.2260, -0.0142, -0.4911,  0.7699,  0.1078
-0.2335, -0.4049,  0.4352, -0.6183, -0.7636,  0.5088
 0.1810, -0.5142,  0.2465,  0.2767, -0.3449,  0.3136
-0.8650,  0.7611, -0.0801,  0.5277, -0.4922,  0.7140
-0.2358, -0.7466, -0.5115, -0.8413, -0.3943,  0.4533
 0.4834,  0.2300,  0.3448, -0.9832,  0.3568,  0.1360
-0.6502, -0.6300,  0.6885,  0.9652,  0.8275,  0.3046
-0.3053,  0.5604,  0.0929,  0.6329, -0.0325,  0.4756
-0.7995,  0.0740, -0.2680,  0.2086,  0.9176,  0.4565
-0.2144, -0.2141,  0.5813,  0.2902, -0.2122,  0.4119
-0.7278, -0.0987, -0.3312, -0.5641,  0.8515,  0.4438
 0.3793,  0.1976,  0.4933,  0.0839,  0.4011,  0.1905
-0.8568,  0.9573, -0.5272,  0.3212, -0.8207,  0.7415
-0.5785,  0.0056, -0.7901, -0.2223,  0.0760,  0.5551
 0.0735, -0.2188,  0.3925,  0.3570,  0.3746,  0.2191
 0.1230, -0.2838,  0.2262,  0.8715,  0.1938,  0.2878
 0.4792, -0.9248,  0.5295,  0.0366, -0.9894,  0.3149
-0.4456,  0.0697,  0.5359, -0.8938,  0.0981,  0.3879
 0.8629, -0.8505, -0.4464,  0.8385,  0.5300,  0.1769
 0.1995,  0.6659,  0.7921,  0.9454,  0.9970,  0.2330
-0.0249, -0.3066, -0.2927, -0.4923,  0.8220,  0.2437
 0.4513, -0.9481, -0.0770, -0.4374, -0.9421,  0.2879
-0.3405,  0.5931, -0.3507, -0.3842,  0.8562,  0.3987
 0.9538,  0.0471,  0.9039,  0.7760,  0.0361,  0.1706
-0.0887,  0.2104,  0.9808,  0.5478, -0.3314,  0.4128
-0.8220, -0.6302,  0.0537, -0.1658,  0.6013,  0.4306
-0.4123, -0.2880,  0.9074, -0.0461, -0.4435,  0.5144
 0.0060,  0.2867, -0.7775,  0.5161,  0.7039,  0.3599
-0.7968, -0.5484,  0.9426, -0.4308,  0.8148,  0.2979
 0.7811,  0.8450, -0.6877,  0.7594,  0.2640,  0.2362
-0.6802, -0.1113, -0.8325, -0.6694, -0.6056,  0.6544
 0.3821,  0.1476,  0.7466, -0.5107,  0.2592,  0.1648
 0.7265,  0.9683, -0.9803, -0.4943, -0.5523,  0.2454
-0.9049, -0.9797, -0.0196, -0.9090, -0.4433,  0.6447
-0.4607,  0.1811, -0.2389,  0.4050, -0.0078,  0.5229
 0.2664, -0.2932, -0.4259, -0.7336,  0.8742,  0.1834
-0.4507,  0.1029, -0.6294, -0.1158, -0.6294,  0.6081
 0.8948, -0.0124,  0.9278,  0.2899, -0.0314,  0.1534
-0.1323, -0.8813, -0.0146, -0.0697,  0.6135,  0.2386