Decision Tree Regression (No Recursion, No Pointers) From Scratch Using C#

The goal of a machine learning regression system is to predict a single numeric value. For example, you might want to predict the bank account balance of a person based on age, annual salary, height, and years of education.

Decision tree regression is a technique that stores a set of if-then rules in a tree data structure. For example, a decision tree regression model prediction might have a form like, “If age is greater than 43.0 and age is less than or equal to 51.5 and years of education is less than or equal to 20.0 and height is greater than 58.0 then bank account balance is $845.41.”

There are many ways to design a decision tree regression system. Three major design decisions are 1.) whether to use pointers/references for tree nodes, 2.) whether to use recursion when constructing the tree, and 3.) which splitting criterion to use (typically mean squared error minimization, or explicit variance reduction, or mean absolute deviation minimization). My standard implementation is: use pointers/references, no recursion, mean squared error minimization.

Decision trees are rarely used by themselves because they usually overfit the training data severely. But a collection/ensemble of simple decision trees is used for bagging tree regression, random forest regression, adaptive boosting regression, and gradient boosting regression. In ensemble techniques, using pointers/references is fine. However, in some scenarios, avoiding the use of pointers/references is needed. If you don’t use pointers, then the decision tree must be stored explicitly in an array or list.

I put together a demo of a basic decision tree regression system that uses no pointers (so I used a List for storage), no recursion (so I used a stack data structure to build the tree), and mean squared error minimization (which is equivalent to variance reduction).
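To see why minimizing the size-weighted MSE of the two child nodes amounts to the same thing as maximizing variance reduction, note that the variance of the parent node is fixed, so the reduction is just the parent variance minus the weighted child MSE. Here is a minimal sketch of that arithmetic. The class name SplitDemo, the helper names, and the four y values are made up for illustration and are not part of the demo program.

```csharp
using System;

public static class SplitDemo
{
  // population variance = MSE of the values around their own mean
  public static double Variance(double[] v)
  {
    double mean = 0.0;
    foreach (double x in v) mean += x;
    mean /= v.Length;
    double sum = 0.0;
    foreach (double x in v) sum += (x - mean) * (x - mean);
    return sum / v.Length;
  }

  // size-weighted average of the two child variances
  public static double WeightedSplitMse(double[] left,
    double[] right)
  {
    int n = left.Length + right.Length;
    return (left.Length * Variance(left) +
      right.Length * Variance(right)) / n;
  }

  public static void Main()
  {
    // made-up target values, for illustration only
    double[] parent = { 0.10, 0.20, 0.80, 0.90 };
    double[] left = { 0.10, 0.20 };
    double[] right = { 0.80, 0.90 };

    double parentVar = Variance(parent);
    double splitMse = WeightedSplitMse(left, right);
    Console.WriteLine("parent variance    = " +
      parentVar.ToString("F4"));
    Console.WriteLine("weighted split MSE = " +
      splitMse.ToString("F4"));
    Console.WriteLine("variance reduction = " +
      (parentVar - splitMse).ToString("F4"));
  }
}
```

For these values the parent variance is 0.1250 and the weighted split MSE is 0.0025, so the variance reduction is 0.1225. Because the parent variance does not depend on the candidate split, the split with the smallest weighted MSE is also the split with the largest reduction.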

For my demo, I used synthetic data that looks like:

-0.1660,  0.4406, -0.9998, -0.3953, -0.7065,  0.4840
 0.0776, -0.1616,  0.3704, -0.5911,  0.7562,  0.1568
-0.9452,  0.3409, -0.1654,  0.1174, -0.7192,  0.8054
. . .

The first five values on each line are predictor values. The last value is the target y value to predict. The data was generated by a 5-10-1 neural network with random weights and bias values. There are 200 training items and 40 test items.

The key output parts of my demo decision tree regression program are:

Begin decision tree regression

Loading synthetic train (200) and test (40) data
Done

Setting maxDepth = 3
Setting minSamples = 2
Setting minLeaf = 18
Creating and training tree
Done

Tree:
ID 0      0  -0.2102     1     2   0.3493  False
ID 1      4   0.1431     3     4   0.0000  False
ID 2      0   0.3915     5     6   0.0000  False
ID 3      0  -0.6553     7     8   0.0000  False
ID 4     -1   0.0000     9    10   0.4123   True
ID 5      4  -0.2987    11    12   0.0000  False
ID 6      2   0.3777    13    14   0.0000  False
ID 7     -1   0.0000    15    16   0.6952   True
ID 8     -1   0.0000    17    18   0.5598   True
ID 9     -1   0.0000    -1    -1   0.0000  False
ID 10    -1   0.0000    -1    -1   0.0000  False
ID 11    -1   0.0000    23    24   0.4101   True
ID 12    -1   0.0000    25    26   0.2613   True
ID 13    -1   0.0000    27    28   0.1882   True
ID 14    -1   0.0000    29    30   0.1381   True

Evaluating model
Accuracy train (within 0.10) = 0.3750
Accuracy test (within 0.10) = 0.4750

MSE train = 0.0048
MSE test = 0.0054

Predicting for trainX[0] =
  -0.1660   0.4406  -0.9998  -0.3953  -0.7065
Predicted y = 0.4101

IF
column 0  gt   -0.2102 AND
column 0  lte   0.3915 AND
column 4  lte  -0.2987 AND
THEN
predicted = 0.4101

End demo

In the output above, the tree node fields are ID, split column, split value, left child index, right child index, prediction value, and is_leaf. Notice that the demo tree has 15 nodes, from 0 to 14, so left and right child indexes that are greater than 14 are equivalent to null pointers.
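The index-based layout can be illustrated with a small standalone sketch. It copies the 15 rows of the demo output into parallel arrays (a simplification of the demo's List of Node objects) and relies on the fact that, in the demo, the children of node i sit at indexes 2i+1 and 2i+2. The class name ArrayTreeDemo and the array names are hypothetical and used only here.

```csharp
using System;

public static class ArrayTreeDemo
{
  // Parallel arrays indexed by node ID, copied from the demo output.
  // Nodes 9 and 10 are unused placeholders (children of leaf 4).
  static readonly int[] col =
    { 0, 4, 0, 0, -1, 4, 2,
      -1, -1, -1, -1, -1, -1, -1, -1 };
  static readonly double[] thresh =
    { -0.2102, 0.1431, 0.3915, -0.6553, 0.0, -0.2987, 0.3777,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
  static readonly double[] leafValue =
    { 0.3493, 0.0, 0.0, 0.0, 0.4123, 0.0, 0.0,
      0.6952, 0.5598, 0.0, 0.0, 0.4101, 0.2613, 0.1882, 0.1381 };
  static readonly bool[] isLeaf =
    { false, false, false, false, true, false, false,
      true, true, false, false, true, true, true, true };

  public static double Predict(double[] x)
  {
    int p = 0;  // start at the root
    while (!isLeaf[p])
    {
      // children of node p live at 2p+1 (left) and 2p+2 (right)
      p = (x[col[p]] <= thresh[p]) ? 2 * p + 1 : 2 * p + 2;
    }
    return leafValue[p];
  }

  public static void Main()
  {
    // trainX[0] from the demo; the walk visits nodes 0, 2, 5, 11
    double[] x = { -0.1660, 0.4406, -0.9998, -0.3953, -0.7065 };
    Console.WriteLine(Predict(x).ToString("F4"));
  }
}
```

The walk for trainX[0] goes right at node 0, left at node 2, and left at node 5, ending at leaf node 11 with value 0.4101, which agrees with the IF-THEN rule in the demo output.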

The three model parameters are maxDepth, minSamples, and minLeaf. All three control the size and shape of the decision tree. If maxDepth is set to 0, the tree has just a single root node and the prediction for any input is the average of the target y values in the training data. If maxDepth is set to 1, the tree has at most 3 nodes (root, left child, right child). If maxDepth is set to n, the tree has at most 2^(n+1)-1 nodes. The larger maxDepth is, the more accurate the decision tree will be on the training data at the expense of worse overfitting.
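The node-count bound comes from summing the largest possible number of nodes at each depth level, where depth d can hold at most 2^d nodes. Written out as a short worked equation (the symbol N_max is introduced here only for illustration):

```latex
N_{\max}(n) \;=\; \sum_{d=0}^{n} 2^{d} \;=\; 2^{\,n+1} - 1,
\qquad
N_{\max}(0) = 1,\quad N_{\max}(1) = 3,\quad N_{\max}(3) = 15
```

The last value matches the 15 node slots (ID 0 through ID 14) shown in the demo output for maxDepth = 3.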


A diagram that illustrates the demo decision tree.

The minSamples value sets how many rows of data must be associated with a tree node before the node is split into two child nodes. The minSamples value must be 2 or greater, and 2 is a good default value to use. The larger minSamples is, the less accurate the decision tree will be on the training data, but the tree will be less overfitted.

The minLeaf value sets how many rows of data must be associated with each of the two child nodes after a split takes place. The minLeaf value must be 1 or greater, and 1 is a good default value to use. The demo uses 18 (out of 200 training data items) only because it gave a representative demo.

The maxDepth, minSamples, and minLeaf values interact in complex ways. In practice, tuning a decision tree regression system usually starts with setting minSamples = 2, and minLeaf = 1, and maxDepth = 3, and then examining what happens as maxDepth is increased.
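One way to picture how the three parameters interact is to isolate the two checks they control: whether a node may be split at all (maxDepth and minSamples), and whether a particular candidate split is acceptable (minLeaf). The sketch below is a simplified restatement of those checks, not the demo code itself. The class and method names are hypothetical.

```csharp
using System;

public static class SplitGateDemo
{
  // A node is a candidate for splitting only if it is above the
  // depth limit and holds at least minSamples rows.
  public static bool MayTrySplit(int depth, int nRows,
    int maxDepth, int minSamples)
  {
    return depth < maxDepth && nRows >= minSamples;
  }

  // A candidate split is accepted only if both children
  // would receive at least minLeaf rows.
  public static bool SplitIsValid(int nLeft, int nRight,
    int minLeaf)
  {
    return nLeft >= minLeaf && nRight >= minLeaf;
  }

  public static void Main()
  {
    // the demo settings
    int maxDepth = 3; int minSamples = 2; int minLeaf = 18;

    // root node with all 200 rows: may be split
    Console.WriteLine(MayTrySplit(0, 200, maxDepth, minSamples));
    // node already at depth 3: becomes a leaf
    Console.WriteLine(MayTrySplit(3, 40, maxDepth, minSamples));
    // 183/17 split leaves one child below minLeaf: rejected
    Console.WriteLine(SplitIsValid(183, 17, minLeaf));
    // 100/100 split: accepted
    Console.WriteLine(SplitIsValid(100, 100, minLeaf));
  }
}
```

With minLeaf = 18, many candidate thresholds near the edges of a column are rejected, which is one reason a node can end up as a leaf before maxDepth is reached.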



Decision trees are not plants but they’re generally good. One of my all-time favorite novels is “The Gods of Mars” (1918) by Edgar Rice Burroughs. It is the second book of the Mars series, after “A Princess of Mars”.

John Carter, and his friend the 15-foot tall green Martian Tars Tarkas, discover an evil race of Therns who use vicious one-eyed Plant Men to capture or kill unsuspecting victims who travel on a false pilgrimage to the Valley of Dor.


Demo program. Replace “lt” (less than), “gt”, “lte”, “gte”, “and” with the corresponding C# operator symbols, including the angle brackets in generic types such as List and Stack (my blog editor frequently chokes on symbols).

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;  // needed for ToList() on arrays

// full (not-sparse) List storage with indexes (no pointers)

namespace DecisionTreeRegression
{
  internal class DecisionTreeRegressionProgram
  {
    static void Main(string[] args)
    {
      Console.WriteLine("\nBegin decision tree regression ");

      // 1. load data
      Console.WriteLine("\nLoading synthetic train" +
        " (200) and test (40) data");
      string trainFile = "..\\..\\..\\Data\\" +
        "synthetic_train_200.txt";
      int[] colsX = new int[] { 0, 1, 2, 3, 4 };
      double[][] trainX =
        MatLoad(trainFile, colsX, ',', "#");
      double[] trainY =
        MatToVec(MatLoad(trainFile,
        new int[] { 5 }, ',', "#"));

      string testFile = "..\\..\\..\\Data\\" +
        "synthetic_test_40.txt";
      double[][] testX =
        MatLoad(testFile, colsX, ',', "#");
      double[] testY =
        MatToVec(MatLoad(testFile,
        new int[] { 5 }, ',', "#"));
      Console.WriteLine("Done ");

      Console.WriteLine("\nFirst three train X: ");
      for (int i = 0; i "lt" 3; ++i)
        VecShow(trainX[i], 4, 8);

      Console.WriteLine("\nFirst three train y: ");
      for (int i = 0; i "lt" 3; ++i)
        Console.WriteLine(trainY[i].ToString("F4").
          PadLeft(8));

      // 2. create and train/build tree
      int maxDepth = 3;
      int minSamples = 2;
      int minLeaf = 18;

      Console.WriteLine("\nSetting maxDepth = " +
        maxDepth);
      Console.WriteLine("Setting minSamples = " +
        minSamples);
      Console.WriteLine("Setting minLeaf = " +
        minLeaf);

      Console.WriteLine("Creating and training tree ");
      DecisionTreeRegressor tree =
        new DecisionTreeRegressor(maxDepth, minSamples,
        minLeaf, numSplitCols: -1, seed:0);
      tree.Train(trainX, trainY);
      Console.WriteLine("Done ");

      Console.WriteLine("\nTree: ");
      tree.Display();

      // 3. evaluate model
      Console.WriteLine("\nEvaluating model ");
      double accTrain = tree.Accuracy(trainX, trainY, 0.10);
      Console.WriteLine("Accuracy train (within 0.10) = " +
        accTrain.ToString("F4"));
      double accTest = tree.Accuracy(testX, testY, 0.10);
      Console.WriteLine("Accuracy test (within 0.10) = " +
        accTest.ToString("F4"));

      double mseTrain = tree.MSE(trainX, trainY);
      Console.WriteLine("\nMSE train = " +
        mseTrain.ToString("F4"));
      double mseTest = tree.MSE(testX, testY);
      Console.WriteLine("MSE test = " +
        mseTest.ToString("F4"));

      // 4. use model
      Console.WriteLine("\nPredicting for trainX[0] = ");
      double[] x = trainX[0];  
      VecShow(x, 4, 9);
      double predY = tree.Predict(x);
      Console.WriteLine("Predicted y = " +
        predY.ToString("F4"));

      tree.Explain(x);

      //// 5. Save() and Load()
      //Console.WriteLine("\nSaving model to file ");
      //string fn =
      //  "..\\..\\..\\Models\\tree_model_3_2_18.txt";
      //tree.Save(fn);
      //Console.WriteLine("Done ");

      //Console.WriteLine("\nLoading saved model to new tree ");
      //DecisionTreeRegressor tree2 =
      //  new DecisionTreeRegressor(maxDepth, minSamples);
      //tree2.Load(fn);
      //Console.WriteLine("Done ");

      //Console.WriteLine("\nNew tree: ");
      //tree2.Display();

      //Console.ReadLine();

      //Console.WriteLine("\nUsing saved model " +
      //  "to predict for: ");
      //VecShow(x, 4, 9);
      //predY = tree2.Predict(x);
      //Console.WriteLine("Predicted y = " +
      //  predY.ToString("F4"));

      ////tree2.Explain(x);

      Console.WriteLine("\nEnd demo ");
      Console.ReadLine();

    } // Main()

    // ------------------------------------------------------
    // helpers for Main()
    // ------------------------------------------------------

    static double[][] MatLoad(string fn, int[] usecols,
      char sep, string comment)
    {
      List"lt"double[]"gt" result =
        new List"lt"double[]"gt"();
      string line = "";
      FileStream ifs = new FileStream(fn, FileMode.Open);
      StreamReader sr = new StreamReader(ifs);
      while ((line = sr.ReadLine()) != null)
      {
        if (line.StartsWith(comment) == true)
          continue;
        string[] tokens = line.Split(sep);
        List"lt"double"gt" lst = new List"lt"double"gt"();
        for (int j = 0; j "lt" usecols.Length; ++j)
          lst.Add(double.Parse(tokens[usecols[j]]));
        double[] row = lst.ToArray();
        result.Add(row);
      }
      sr.Close(); ifs.Close();
      return result.ToArray();
    }

    static double[] MatToVec(double[][] mat)
    {
      int nRows = mat.Length;
      int nCols = mat[0].Length;
      double[] result = new double[nRows * nCols];
      int k = 0;
      for (int i = 0; i "lt" nRows; ++i)
        for (int j = 0; j "lt" nCols; ++j)
          result[k++] = mat[i][j];
      return result;
    }

    static void VecShow(double[] vec, int dec, int wid)
    {
      for (int i = 0; i "lt" vec.Length; ++i)
        Console.Write(vec[i].ToString("F" + dec).
          PadLeft(wid));
      Console.WriteLine("");
    }

  } // class DecisionTreeRegressionProgram

  // ========================================================

  public class DecisionTreeRegressor
  {
    public int maxDepth;
    public int minSamples;  // aka min_samples_split
    public int minLeaf;  // min number of values in a leaf
    public int numSplitCols; // mostly for random forest
    public List"lt"Node"gt" tree = new List"lt"Node"gt"();
    public Random rnd;  // order in which cols are searched

    // ------------------------------------------------------

    public class Node
    {
      public int id;
      public int colIdx;  // aka featureIdx
      public double thresh;

      public int left;  // index into List
      public int right;

      public double value;
      public bool isLeaf;

      // ------------------------------------------

      public Node(int id, int colIdx, double thresh,
        int left, int right, double value, bool isLeaf)
      {
        this.id = id;
        this.colIdx = colIdx;
        this.thresh = thresh;
        this.left = left;
        this.right = right;
        this.value = value;
        this.isLeaf = isLeaf;
      }

      public void Show()
      {
        string s1 = "ID " + 
          this.id.ToString().PadRight(3) + "  ";
        string s2 = this.colIdx.ToString().PadLeft(3) + " ";
        string s3 = this.thresh.ToString("F4").
          PadLeft(8) + " ";
        string s4 = this.left.ToString().PadLeft(5) + " ";
        string s5 = this.right.ToString().PadLeft(5) + " ";
        string s6 = this.value.ToString("F4").
          PadLeft(8) + " ";
        string s7 = this.isLeaf.ToString().PadLeft(6);
        Console.WriteLine(s1 + s2 + s3 + s4 + s5 + s6 + s7);
      }
    } // class Node

    // --------------------------------------------

    public class StackInfo  // used to build tree
    {
      // tree node + associated rows
      public Node node;
      public double[][] dataX;
      public double[] dataY;
      public int depth;

      public StackInfo(Node n, double[][] X,
        double[] y, int d)
      {
        this.node = n; this.dataX = X;
        this.dataY = y; this.depth = d;
      }
    }

    // --------------------------------------------

    public DecisionTreeRegressor(int maxDepth = 2,
      int minSamples = 2, int minLeaf = 1,
      int numSplitCols = -1, int seed = 0)
    {
      // if maxDepth = 0, tree has just a root node
      // if maxDepth = 1, at most 3 nodes (root, l, r)
      // if maxDepth = n, at most 2^(n+1) - 1 nodes
      this.maxDepth = maxDepth;
      this.minSamples = minSamples;
      this.minLeaf = minLeaf;
      this.numSplitCols = numSplitCols;  // for ran. forest
 
      // create full tree List with dummy nodes
      int numNodes = (int)Math.Pow(2, (maxDepth + 1)) - 1;
      for (int i = 0; i "lt" numNodes; ++i)
      {
        Node n = new Node(i, -1, 0.0, -1, -1, 0.0, false);
        this.tree.Add(n);
      }
      this.rnd = new Random(seed);
    }

    // ------------------------------------------------------

    private double[] BestSplit(double[][] dataX,
      double[] dataY)
    {
      // implicit params numSplitCols, minLeaf
      // result[0] = best col idx (as double)
      // result[1] = best split value

      int bestColIdx = -1;  // indicates bad split
      double bestThresh = 0.0;
      double bestMSE = double.MaxValue;  // smaller is better
      int nRows = dataX.Length;  // or dataY.Length
      int nCols = dataX[0].Length;

      if (nRows == 0)
      {
        throw new Exception("empty data in BestSplit()");
      }

      int[] colIndices = new int[nCols];
      for (int k = 0; k "lt" nCols; ++k)
        colIndices[k] = k;
      Shuffle(colIndices, this.rnd);

      if (this.numSplitCols != -1)  // use only some cols
      {
        int[] cpyIndices = new int[nCols];
        for (int k = 0; k "lt" nCols; ++k)
          cpyIndices[k] = colIndices[k];
        colIndices = new int[this.numSplitCols];
        for (int k = 0; k "lt" this.numSplitCols; ++k)
          colIndices[k] = cpyIndices[k];
      }

      for (int j = 0; j "lt" colIndices.Length; ++j)
      {
        int colIdx = colIndices[j];
        HashSet"lt"double"gt" examineds = 
          new HashSet"lt"double"gt"();

        for (int i = 0; i "lt" nRows; ++i) // each row
        {
          // if curr thresh has been seen, skip it
          double thresh = dataX[i][colIdx];
          if (examineds.Contains(thresh)) continue;

          examineds.Add(thresh);

          // get rows where x is lte, gt thresh
          List"lt"int"gt" leftIdxs = new List"lt"int"gt"();
          List"lt"int"gt" rightIdxs = new List"lt"int"gt"();
          for (int r = 0; r "lt" nRows; ++r)
          {
            if (dataX[r][colIdx] "lte" thresh)
              leftIdxs.Add(r);
            else
              rightIdxs.Add(r);
          }

          // Check if proposed split would lead to an empty
          // node, which would cause Mean() to fail
          // when building the tree.
          // But allow a node with a single value.
          // if (leftIdxs.Count == 0 ||
          //   rightIdxs.Count == 0)
          //   continue; // scikit "min_samples_leaf=1"
          if (leftIdxs.Count "lt" this.minLeaf ||
            rightIdxs.Count "lt" this.minLeaf)
            continue; // like scikit "min_samples_leaf"

          // get the left y values and right y values
          List"lt"double"gt" leftValues = 
            new List"lt"double"gt"();
          for (int k = 0; k "lt" leftIdxs.Count; ++k)
            leftValues.Add(dataY[leftIdxs[k]]);

          List"lt"double"gt" rightValues = 
            new List"lt"double"gt"();
          for (int k = 0; k "lt" rightIdxs.Count; ++k)
            rightValues.Add(dataY[rightIdxs[k]]);

          // compute MSE, equivalent to variance reduction
          double mseLeft = MSE(leftValues);
          double mseRight = MSE(rightValues);
          double splitMSE = (leftValues.Count * mseLeft +
            rightValues.Count * mseRight) / nRows;

          if (splitMSE "lt" bestMSE)
          {
            bestColIdx = colIdx;
            bestThresh = thresh;
            bestMSE = splitMSE;
          }

        } // each row
      } // j each col

      double[] result = new double[2];
      result[0] = 1.0 * bestColIdx;
      result[1] = bestThresh;
      return result;

    } // BestSplit()

    // ------------------------------------------------------

    private static void Shuffle(int[] indices, Random rnd)
    {
      int n = indices.Length;
      for (int i = 0; i "lt" n; ++i)
      {
        int ri = rnd.Next(i, n);
        int tmp = indices[i];
        indices[i] = indices[ri];
        indices[ri] = tmp;
      }
    }

    // ------------------------------------------------------

    private static double Mean(List"lt"double"gt" values)
    {
      double sum = 0.0;
      for (int i = 0; i "lt" values.Count; ++i)
        sum += values[i];
      return sum / values.Count;
    }

    // ------------------------------------------------------

    private static double MSE(List"lt"double"gt" values)
    {
      double mean = Mean(values);
      double sum = 0.0;
      for (int i = 0; i "lt" values.Count; ++i)
        sum += (values[i] - mean) * (values[i] - mean);
      return sum / values.Count;
    }

    // ------------------------------------------------------

    private void MakeTree(double[][] dataX,
      double[] dataY, int depth = 0)
    {
      // no recursion, no pointers, List storage

      Node root = new Node(0, -1, 0.0, 1, 2,
        Mean(dataY.ToList()), false);
      Stack"lt"StackInfo"gt" stack = 
        new Stack"lt"StackInfo"gt"();

      stack.Push(new StackInfo(root, dataX, dataY, 0));
      while (stack.Count "gt" 0)
      {
        StackInfo info = stack.Pop();
        Node currNode = info.node;
        this.tree[currNode.id] = currNode;

        double[][] currX = info.dataX;
        double[] currY = info.dataY;
        int currDepth = info.depth;

        if (currDepth == this.maxDepth ||
          currY.Length "lt" this.minSamples) {
          currNode.value = Mean(currY.ToList());
          currNode.isLeaf = true;
          continue;
        }

        double[] splitInfo = this.BestSplit(currX, currY);
        int colIdx = (int)splitInfo[0];
        double thresh = splitInfo[1];

        if (colIdx == -1)  // unable to split
        {
          currNode.value = Mean(currY.ToList());
          currNode.isLeaf = true;
          continue;
        }

        // got a good split so at an internal, non-leaf node
        currNode.colIdx = colIdx;
        currNode.thresh = thresh;

        // make the data splits for children
        List"lt"int"gt" leftIdxs = new List"lt"int"gt"();
        List"lt"int"gt" rightIdxs = new List"lt"int"gt"();
        for (int r = 0; r "lt" currX.Length; ++r)
        {
          if (currX[r][colIdx] "lte" thresh)
            leftIdxs.Add(r);
          else
            rightIdxs.Add(r);
        }

        double[][] leftX = 
          this.ExtractRows(currX, leftIdxs);
        double[] leftY = 
          this.ExtractVals(currY, leftIdxs);
        double[][] rightX = 
          this.ExtractRows(currX, rightIdxs);
        double[] rightY = 
          this.ExtractVals(currY, rightIdxs);

        int leftID = currNode.id * 2 + 1;
        Node currNodeLeft = new Node(leftID, -1, 0.0,
          2*leftID + 1, 2 * leftID + 2, 0.0, false);
        stack.Push(new StackInfo(currNodeLeft, leftX,
          leftY, currDepth + 1));

        int rightID = currNode.id * 2 + 2;
        Node currNodeRight = new Node(rightID, -1, 0.0,
          2*rightID + 1, 2 * rightID + 2, 0.0, false);
        stack.Push(new StackInfo(currNodeRight, rightX,
          rightY, currDepth + 1));
 
      } // while

      return;

    } // MakeTree()

    // ------------------------------------------------------

    private double[][] ExtractRows(double[][] source,
      List"lt"int"gt" rowIdxs)
    {
      // surprisingly tricky
      int numSrcRows = source.Length;
      int numSrcCols = source[0].Length;
      int numDestRows = rowIdxs.Count;
      int numDestCols = numSrcCols;

      double[][] result = new double[numDestRows][];
      for (int i = 0; i "lt" numDestRows; ++i)
        result[i] = new double[numDestCols];

      for (int i = 0; i "lt" rowIdxs.Count; ++i)
      {
        int srcRow = rowIdxs[i];
        for (int j = 0; j "lt" numDestCols; ++j)
          result[i][j] = source[srcRow][j];
      }
      return result;
    }

    // ------------------------------------------------------

    private double[] ExtractVals(double[] source,
      List"lt"int"gt" idxs)
    {
      double[] result = new double[idxs.Count];
      for (int i = 0; i "lt" idxs.Count; ++i)
      {
        int srcIdx = idxs[i];
        result[i] = source[srcIdx];
      }
      return result;
    }

    // ------------------------------------------------------

    public void Train(double[][] trainX, double[] trainY)
    {
      this.MakeTree(trainX, trainY);
    }

    // ------------------------------------------------------

    public double Predict(double[] x)
    {
      int p = 0;
      Node currNode = this.tree[p];
      while (currNode.isLeaf == false "and" 
        p "lt" this.tree.Count)
      {
        if (x[currNode.colIdx] "lte" currNode.thresh)
          p = currNode.left;
        else
          p = currNode.right;
        currNode = this.tree[p];
      }
      return this.tree[p].value;

    }

    // ------------------------------------------------------

    public void Explain(double[] x)
    {
      int p = 0;
      Console.WriteLine("\nIF ");
      Node currNode = this.tree[p];
      while (currNode.isLeaf == false)
      {
        Console.Write("column ");
        Console.Write(currNode.colIdx + " ");
        if (x[currNode.colIdx] "lte" currNode.thresh)
        {
          Console.Write(" "lte" ");
          Console.WriteLine(currNode.thresh.
            ToString("F4").PadLeft(8) + " AND ");
          //currNode = currNode.left;
          p = currNode.left;
          currNode = this.tree[p];
        }
        else
        {
          Console.Write(" "gt"  ");
          Console.WriteLine(currNode.thresh.
            ToString("F4").PadLeft(8) + " AND ");
          p = currNode.right;
          currNode = this.tree[p];
        }
      }
      Console.WriteLine("THEN \npredicted = " +
        currNode.value.ToString("F4"));
    }

    // ------------------------------------------------------

    public double Accuracy(double[][] dataX, double[] dataY,
      double pctClose)
    {
      int numCorrect = 0; int numWrong = 0;
      for (int i = 0; i "lt" dataX.Length; ++i)
      {
        double actualY = dataY[i];
        double predY = this.Predict(dataX[i]);
        if (Math.Abs(predY - actualY) "lt"
          Math.Abs(pctClose * actualY))
          ++numCorrect;
        else
          ++numWrong;
      }
      return (numCorrect * 1.0) / (numWrong + numCorrect);
    }

    // ------------------------------------------------------

    public double MSE(double[][] dataX,
      double[] dataY)
    {
      int n = dataX.Length;
      double sum = 0.0;
      for (int i = 0; i "lt" n; ++i)
      {
        double actualY = dataY[i];
        double predY = this.Predict(dataX[i]);
        sum += (actualY - predY) * (actualY - predY);
      }
      return sum / n;
    }

    // ------------------------------------------------------

    public void Display()
    {
      for (int i = 0; i "lt" this.tree.Count; ++i)
      {
        Node n = this.tree[i];
        n.Show();
      }
    }

    // ------------------------------------------------------

    public void Save(string fn)
    {
      // nodeID, colIdx, thresh, leftIdx, rightIdx,
      //   value, isLeaf
      // ex:
      // 0, 4, 0.1234, 1, 2, 0.9876, false

      FileStream ofs = new FileStream(fn, FileMode.Create);
      StreamWriter sw = new StreamWriter(ofs);

      // traverse tree
      for (int i = 0; i "lt" this.tree.Count; ++i)
      {
        Node n = this.tree[i];
        string s = "";
        s += n.id + ",";
        s += n.colIdx + ",";
        s += n.thresh.ToString("F4") + ",";
        s += n.left + ",";
        s += n.right + ",";
        s += n.value.ToString("F4") + ",";
        s += n.isLeaf.ToString();

        sw.WriteLine(s);
      }

      sw.Close(); ofs.Close();
      return;
    } // Save()

    public void Load(string fn)
    {
      // nodeID, colIdx, thresh, leftIdx, rightIdx,
      //   value, isLeaf
      // ex:
      // 0, 4, 0.1234, 1, 2, 0.9876, false
      FileStream ifs = new FileStream(fn, FileMode.Open);
      StreamReader sr = new StreamReader(ifs);

      int maxNodes = (int)Math.Pow(2, this.maxDepth + 1) - 1;
      this.tree = 
        new List"lt"Node"gt"(); // zap away existing tree
      for (int i = 0; i "lt" maxNodes; ++i)
        this.tree.Add(new Node(-1, -1, 0.0,
          -1, -1, 0.0, false)); // dummy node

      string line = "";
      string[] tokens = null;
      while ((line = sr.ReadLine()) != null)
      {
        tokens = line.Split(',');
        int idx = int.Parse(tokens[0]); // where ?
        this.tree[idx].id = idx;
        this.tree[idx].colIdx = int.Parse(tokens[1]);
        this.tree[idx].thresh = double.Parse(tokens[2]);

        int leftIdx = int.Parse(tokens[3]);
        this.tree[idx].left = leftIdx;

        int rightIdx = int.Parse(tokens[4]);
        this.tree[idx].right = rightIdx;

        this.tree[idx].value = double.Parse(tokens[5]);
        this.tree[idx].isLeaf = bool.Parse(tokens[6]);
      }

      sr.Close(); ifs.Close();
      return;
    } // Load()

    // ------------------------------------------------------

  } // class DecisionTreeRegressor

} // ns

Training data:

# synthetic_train_200.txt
#
-0.1660,  0.4406, -0.9998, -0.3953, -0.7065,  0.4840
 0.0776, -0.1616,  0.3704, -0.5911,  0.7562,  0.1568
-0.9452,  0.3409, -0.1654,  0.1174, -0.7192,  0.8054
 0.9365, -0.3732,  0.3846,  0.7528,  0.7892,  0.1345
-0.8299, -0.9219, -0.6603,  0.7563, -0.8033,  0.7955
 0.0663,  0.3838, -0.3690,  0.3730,  0.6693,  0.3206
-0.9634,  0.5003,  0.9777,  0.4963, -0.4391,  0.7377
-0.1042,  0.8172, -0.4128, -0.4244, -0.7399,  0.4801
-0.9613,  0.3577, -0.5767, -0.4689, -0.0169,  0.6861
-0.7065,  0.1786,  0.3995, -0.7953, -0.1719,  0.5569
 0.3888, -0.1716, -0.9001,  0.0718,  0.3276,  0.2500
 0.1731,  0.8068, -0.7251, -0.7214,  0.6148,  0.3297
-0.2046, -0.6693,  0.8550, -0.3045,  0.5016,  0.2129
 0.2473,  0.5019, -0.3022, -0.4601,  0.7918,  0.2613
-0.1438,  0.9297,  0.3269,  0.2434, -0.7705,  0.5171
 0.1568, -0.1837, -0.5259,  0.8068,  0.1474,  0.3307
-0.9943,  0.2343, -0.3467,  0.0541,  0.7719,  0.5581
 0.2467, -0.9684,  0.8589,  0.3818,  0.9946,  0.1092
-0.6553, -0.7257,  0.8652,  0.3936, -0.8680,  0.7018
 0.8460,  0.4230, -0.7515, -0.9602, -0.9476,  0.1996
-0.9434, -0.5076,  0.7201,  0.0777,  0.1056,  0.5664
 0.9392,  0.1221, -0.9627,  0.6013, -0.5341,  0.1533
 0.6142, -0.2243,  0.7271,  0.4942,  0.1125,  0.1661
 0.4260,  0.1194, -0.9749, -0.8561,  0.9346,  0.2230
 0.1362, -0.5934, -0.4953,  0.4877, -0.6091,  0.3810
 0.6937, -0.5203, -0.0125,  0.2399,  0.6580,  0.1460
-0.6864, -0.9628, -0.8600, -0.0273,  0.2127,  0.5387
 0.9772,  0.1595, -0.2397,  0.1019,  0.4907,  0.1611
 0.3385, -0.4702, -0.8673, -0.2598,  0.2594,  0.2270
-0.8669, -0.4794,  0.6095, -0.6131,  0.2789,  0.4700
 0.0493,  0.8496, -0.4734, -0.8681,  0.4701,  0.3516
 0.8639, -0.9721, -0.5313,  0.2336,  0.8980,  0.1412
 0.9004,  0.1133,  0.8312,  0.2831, -0.2200,  0.1782
 0.0991,  0.8524,  0.8375, -0.2102,  0.9265,  0.2150
-0.6521, -0.7473, -0.7298,  0.0113, -0.9570,  0.7422
 0.6190, -0.3105,  0.8802,  0.1640,  0.7577,  0.1056
 0.6895,  0.8108, -0.0802,  0.0927,  0.5972,  0.2214
 0.1982, -0.9689,  0.1870, -0.1326,  0.6147,  0.1310
-0.3695,  0.7858,  0.1557, -0.6320,  0.5759,  0.3773
-0.1596,  0.3581,  0.8372, -0.9992,  0.9535,  0.2071
-0.2468,  0.9476,  0.2094,  0.6577,  0.1494,  0.4132
 0.1737,  0.5000,  0.7166,  0.5102,  0.3961,  0.2611
 0.7290, -0.3546,  0.3416, -0.0983, -0.2358,  0.1332
-0.3652,  0.2438, -0.1395,  0.9476,  0.3556,  0.4170
-0.6029, -0.1466, -0.3133,  0.5953,  0.7600,  0.4334
-0.4596, -0.4953,  0.7098,  0.0554,  0.6043,  0.2775
 0.1450,  0.4663,  0.0380,  0.5418,  0.1377,  0.2931
-0.8636, -0.2442, -0.8407,  0.9656, -0.6368,  0.7429
 0.6237,  0.7499,  0.3768,  0.1390, -0.6781,  0.2185
-0.5499,  0.1850, -0.3755,  0.8326,  0.8193,  0.4399
-0.4858, -0.7782, -0.6141, -0.0008,  0.4572,  0.4197
 0.7033, -0.1683,  0.2334, -0.5327, -0.7961,  0.1776
 0.0317, -0.0457, -0.6947,  0.2436,  0.0880,  0.3345
 0.5031, -0.5559,  0.0387,  0.5706, -0.9553,  0.3107
-0.3513,  0.7458,  0.6894,  0.0769,  0.7332,  0.3170
 0.2205,  0.5992, -0.9309,  0.5405,  0.4635,  0.3532
-0.4806, -0.4859,  0.2646, -0.3094,  0.5932,  0.3202
 0.9809, -0.3995, -0.7140,  0.8026,  0.0831,  0.1600
 0.9495,  0.2732,  0.9878,  0.0921,  0.0529,  0.1289
-0.9476, -0.6792,  0.4913, -0.9392, -0.2669,  0.5966
 0.7247,  0.3854,  0.3819, -0.6227, -0.1162,  0.1550
-0.5922, -0.5045, -0.4757,  0.5003, -0.0860,  0.5863
-0.8861,  0.0170, -0.5761,  0.5972, -0.4053,  0.7301
 0.6877, -0.2380,  0.4997,  0.0223,  0.0819,  0.1404
 0.9189,  0.6079, -0.9354,  0.4188, -0.0700,  0.1907
-0.1428, -0.7820,  0.2676,  0.6059,  0.3936,  0.2790
 0.5324, -0.3151,  0.6917, -0.1425,  0.6480,  0.1071
-0.8432, -0.9633, -0.8666, -0.0828, -0.7733,  0.7784
-0.9444,  0.5097, -0.2103,  0.4939, -0.0952,  0.6787
-0.0520,  0.6063, -0.1952,  0.8094, -0.9259,  0.4836
 0.5477, -0.7487,  0.2370, -0.9793,  0.0773,  0.1241
 0.2450,  0.8116,  0.9799,  0.4222,  0.4636,  0.2355
 0.8186, -0.1983, -0.5003, -0.6531, -0.7611,  0.1511
-0.4714,  0.6382, -0.3788,  0.9648, -0.4667,  0.5950
 0.0673, -0.3711,  0.8215, -0.2669, -0.1328,  0.2677
-0.9381,  0.4338,  0.7820, -0.9454,  0.0441,  0.5518
-0.3480,  0.7190,  0.1170,  0.3805, -0.0943,  0.4724
-0.9813,  0.1535, -0.3771,  0.0345,  0.8328,  0.5438
-0.1471, -0.5052, -0.2574,  0.8637,  0.8737,  0.3042
-0.5454, -0.3712, -0.6505,  0.2142, -0.1728,  0.5783
 0.6327, -0.6297,  0.4038, -0.5193,  0.1484,  0.1153
-0.5424,  0.3282, -0.0055,  0.0380, -0.6506,  0.6613
 0.1414,  0.9935,  0.6337,  0.1887,  0.9520,  0.2540
-0.9351, -0.8128, -0.8693, -0.0965, -0.2491,  0.7353
 0.9507, -0.6640,  0.9456,  0.5349,  0.6485,  0.1059
-0.0462, -0.9737, -0.2940, -0.0159,  0.4602,  0.2606
-0.0627, -0.0852, -0.7247, -0.9782,  0.5166,  0.2977
 0.0478,  0.5098, -0.0723, -0.7504, -0.3750,  0.3335
 0.0090,  0.3477,  0.5403, -0.7393, -0.9542,  0.4415
-0.9748,  0.3449,  0.3736, -0.1015,  0.8296,  0.4358
 0.2887, -0.9895, -0.0311,  0.7186,  0.6608,  0.2057
 0.1570, -0.4518,  0.1211,  0.3435, -0.2951,  0.3244
 0.7117, -0.6099,  0.4946, -0.4208,  0.5476,  0.1096
-0.2929, -0.5726,  0.5346, -0.3827,  0.4665,  0.2465
 0.4889, -0.5572, -0.5718, -0.6021, -0.7150,  0.2163
-0.7782,  0.3491,  0.5996, -0.8389, -0.5366,  0.6516
-0.5847,  0.8347,  0.4226,  0.1078, -0.3910,  0.6134
 0.8469,  0.4121, -0.0439, -0.7476,  0.9521,  0.1571
-0.6803, -0.5948, -0.1376, -0.1916, -0.7065,  0.7156
 0.2878,  0.5086, -0.5785,  0.2019,  0.4979,  0.2980
 0.2764,  0.1943, -0.4090,  0.4632,  0.8906,  0.2960
-0.8877,  0.6705, -0.6155, -0.2098, -0.3998,  0.7107
-0.8398,  0.8093, -0.2597,  0.0614, -0.0118,  0.6502
-0.8476,  0.0158, -0.4769, -0.2859, -0.7839,  0.7715
 0.5751, -0.7868,  0.9714, -0.6457,  0.1448,  0.1175
 0.4802, -0.7001,  0.1022, -0.5668,  0.5184,  0.1090
 0.4458, -0.6469,  0.7239, -0.9604,  0.7205,  0.0779
 0.5175,  0.4339,  0.9747, -0.4438, -0.9924,  0.2879
 0.8678,  0.7158,  0.4577,  0.0334,  0.4139,  0.1678
 0.5406,  0.5012,  0.2264, -0.1963,  0.3946,  0.2088
-0.9938,  0.5498,  0.7928, -0.5214, -0.7585,  0.7687
 0.7661,  0.0863, -0.4266, -0.7233, -0.4197,  0.1466
 0.2277, -0.3517, -0.0853, -0.1118,  0.6563,  0.1767
 0.3499, -0.5570, -0.0655, -0.3705,  0.2537,  0.1632
 0.7547, -0.1046,  0.5689, -0.0861,  0.3125,  0.1257
 0.8186,  0.2110,  0.5335,  0.0094, -0.0039,  0.1391
 0.6858, -0.8644,  0.1465,  0.8855,  0.0357,  0.1845
-0.4967,  0.4015,  0.0805,  0.8977,  0.2487,  0.4663
 0.6760, -0.9841,  0.9787, -0.8446, -0.3557,  0.1509
-0.1203, -0.4885,  0.6054, -0.0443, -0.7313,  0.4854
 0.8557,  0.7919, -0.0169,  0.7134, -0.1628,  0.2002
 0.0115, -0.6209,  0.9300, -0.4116, -0.7931,  0.4052
-0.7114, -0.9718,  0.4319,  0.1290,  0.5892,  0.3661
 0.3915,  0.5557, -0.1870,  0.2955, -0.6404,  0.2954
-0.3564, -0.6548, -0.1827, -0.5172, -0.1862,  0.4622
 0.2392, -0.4959,  0.5857, -0.1341, -0.2850,  0.2470
-0.3394,  0.3947, -0.4627,  0.6166, -0.4094,  0.5325
 0.7107,  0.7768, -0.6312,  0.1707,  0.7964,  0.2757
-0.1078,  0.8437, -0.4420,  0.2177,  0.3649,  0.4028
-0.3139,  0.5595, -0.6505, -0.3161, -0.7108,  0.5546
 0.4335,  0.3986,  0.3770, -0.4932,  0.3847,  0.1810
-0.2562, -0.2894, -0.8847,  0.2633,  0.4146,  0.4036
 0.2272,  0.2966, -0.6601, -0.7011,  0.0284,  0.2778
-0.0743, -0.1421, -0.0054, -0.6770, -0.3151,  0.3597
-0.4762,  0.6891,  0.6007, -0.1467,  0.2140,  0.4266
-0.4061,  0.7193,  0.3432,  0.2669, -0.7505,  0.6147
-0.0588,  0.9731,  0.8966,  0.2902, -0.6966,  0.4955
-0.0627, -0.1439,  0.1985,  0.6999,  0.5022,  0.3077
 0.1587,  0.8494, -0.8705,  0.9827, -0.8940,  0.4263
-0.7850,  0.2473, -0.9040, -0.4308, -0.8779,  0.7199
 0.4070,  0.3369, -0.2428, -0.6236,  0.4940,  0.2215
-0.0242,  0.0513, -0.9430,  0.2885, -0.2987,  0.3947
-0.5416, -0.1322, -0.2351, -0.0604,  0.9590,  0.3683
 0.1055,  0.7783, -0.2901, -0.5090,  0.8220,  0.2984
-0.9129,  0.9015,  0.1128, -0.2473,  0.9901,  0.4776
-0.9378,  0.1424, -0.6391,  0.2619,  0.9618,  0.5368
 0.7498, -0.0963,  0.4169,  0.5549, -0.0103,  0.1614
-0.2612, -0.7156,  0.4538, -0.0460, -0.1022,  0.3717
 0.7720,  0.0552, -0.1818, -0.4622, -0.8560,  0.1685
-0.4177,  0.0070,  0.9319, -0.7812,  0.3461,  0.3052
-0.0001,  0.5542, -0.7128, -0.8336, -0.2016,  0.3803
 0.5356, -0.4194, -0.5662, -0.9666, -0.2027,  0.1776
-0.2378,  0.3187, -0.8582, -0.6948, -0.9668,  0.5474
-0.1947, -0.3579,  0.1158,  0.9869,  0.6690,  0.2992
 0.3992,  0.8365, -0.9205, -0.8593, -0.0520,  0.3154
-0.0209,  0.0793,  0.7905, -0.1067,  0.7541,  0.1864
-0.4928, -0.4524, -0.3433,  0.0951, -0.5597,  0.6261
-0.8118,  0.7404, -0.5263, -0.2280,  0.1431,  0.6349
 0.0516, -0.8480,  0.7483,  0.9023,  0.6250,  0.1959
-0.3212,  0.1093,  0.9488, -0.3766,  0.3376,  0.2735
-0.3481,  0.5490, -0.3484,  0.7797,  0.5034,  0.4379
-0.5785, -0.9170, -0.3563, -0.9258,  0.3877,  0.4121
 0.3407, -0.1391,  0.5356,  0.0720, -0.9203,  0.3458
-0.3287, -0.8954,  0.2102,  0.0241,  0.2349,  0.3247
-0.1353,  0.6954, -0.0919, -0.9692,  0.7461,  0.3338
 0.9036, -0.8982, -0.5299, -0.8733, -0.1567,  0.1187
 0.7277, -0.8368, -0.0538, -0.7489,  0.5458,  0.0830
 0.9049,  0.8878,  0.2279,  0.9470, -0.3103,  0.2194
 0.7957, -0.1308, -0.5284,  0.8817,  0.3684,  0.2172
 0.4647, -0.4931,  0.2010,  0.6292, -0.8918,  0.3371
-0.7390,  0.6849,  0.2367,  0.0626, -0.5034,  0.7039
-0.1567, -0.8711,  0.7940, -0.5932,  0.6525,  0.1710
 0.7635, -0.0265,  0.1969,  0.0545,  0.2496,  0.1445
 0.7675,  0.1354, -0.7698, -0.5460,  0.1920,  0.1728
-0.5211, -0.7372, -0.6763,  0.6897,  0.2044,  0.5217
 0.1913,  0.1980,  0.2314, -0.8816,  0.5006,  0.1998
 0.8964,  0.0694, -0.6149,  0.5059, -0.9854,  0.1825
 0.1767,  0.7104,  0.2093,  0.6452,  0.7590,  0.2832
-0.3580, -0.7541,  0.4426, -0.1193, -0.7465,  0.5657
-0.5996,  0.5766, -0.9758, -0.3933, -0.9572,  0.6800
 0.9950,  0.1641, -0.4132,  0.8579,  0.0142,  0.2003
-0.4717, -0.3894, -0.2567, -0.5111,  0.1691,  0.4266
 0.3917, -0.8561,  0.9422,  0.5061,  0.6123,  0.1212
-0.0366, -0.1087,  0.3449, -0.1025,  0.4086,  0.2475
 0.3633,  0.3943,  0.2372, -0.6980,  0.5216,  0.1925
-0.5325, -0.6466, -0.2178, -0.3589,  0.6310,  0.3568
 0.2271,  0.5200, -0.1447, -0.8011, -0.7699,  0.3128
 0.6415,  0.1993,  0.3777, -0.0178, -0.8237,  0.2181
-0.5298, -0.0768, -0.6028, -0.9490,  0.4588,  0.4356
 0.6870, -0.1431,  0.7294,  0.3141,  0.1621,  0.1632
-0.5985,  0.0591,  0.7889, -0.3900,  0.7419,  0.2945
 0.3661,  0.7984, -0.8486,  0.7572, -0.6183,  0.3449
 0.6995,  0.3342, -0.3113, -0.6972,  0.2707,  0.1712
 0.2565,  0.9126,  0.1798, -0.6043, -0.1413,  0.2893
-0.3265,  0.9839, -0.2395,  0.9854,  0.0376,  0.4770
 0.2690, -0.1722,  0.9818,  0.8599, -0.7015,  0.3954
-0.2102, -0.0768,  0.1219,  0.5607, -0.0256,  0.3949
 0.8216, -0.9555,  0.6422, -0.6231,  0.3715,  0.0801
-0.2896,  0.9484, -0.7545, -0.6249,  0.7789,  0.4370
-0.9985, -0.5448, -0.7092, -0.5931,  0.7926,  0.5402

Test data:

# synthetic_test_40.txt
#
 0.7462,  0.4006, -0.0590,  0.6543, -0.0083,  0.1935
 0.8495, -0.2260, -0.0142, -0.4911,  0.7699,  0.1078
-0.2335, -0.4049,  0.4352, -0.6183, -0.7636,  0.5088
 0.1810, -0.5142,  0.2465,  0.2767, -0.3449,  0.3136
-0.8650,  0.7611, -0.0801,  0.5277, -0.4922,  0.7140
-0.2358, -0.7466, -0.5115, -0.8413, -0.3943,  0.4533
 0.4834,  0.2300,  0.3448, -0.9832,  0.3568,  0.1360
-0.6502, -0.6300,  0.6885,  0.9652,  0.8275,  0.3046
-0.3053,  0.5604,  0.0929,  0.6329, -0.0325,  0.4756
-0.7995,  0.0740, -0.2680,  0.2086,  0.9176,  0.4565
-0.2144, -0.2141,  0.5813,  0.2902, -0.2122,  0.4119
-0.7278, -0.0987, -0.3312, -0.5641,  0.8515,  0.4438
 0.3793,  0.1976,  0.4933,  0.0839,  0.4011,  0.1905
-0.8568,  0.9573, -0.5272,  0.3212, -0.8207,  0.7415
-0.5785,  0.0056, -0.7901, -0.2223,  0.0760,  0.5551
 0.0735, -0.2188,  0.3925,  0.3570,  0.3746,  0.2191
 0.1230, -0.2838,  0.2262,  0.8715,  0.1938,  0.2878
 0.4792, -0.9248,  0.5295,  0.0366, -0.9894,  0.3149
-0.4456,  0.0697,  0.5359, -0.8938,  0.0981,  0.3879
 0.8629, -0.8505, -0.4464,  0.8385,  0.5300,  0.1769
 0.1995,  0.6659,  0.7921,  0.9454,  0.9970,  0.2330
-0.0249, -0.3066, -0.2927, -0.4923,  0.8220,  0.2437
 0.4513, -0.9481, -0.0770, -0.4374, -0.9421,  0.2879
-0.3405,  0.5931, -0.3507, -0.3842,  0.8562,  0.3987
 0.9538,  0.0471,  0.9039,  0.7760,  0.0361,  0.1706
-0.0887,  0.2104,  0.9808,  0.5478, -0.3314,  0.4128
-0.8220, -0.6302,  0.0537, -0.1658,  0.6013,  0.4306
-0.4123, -0.2880,  0.9074, -0.0461, -0.4435,  0.5144
 0.0060,  0.2867, -0.7775,  0.5161,  0.7039,  0.3599
-0.7968, -0.5484,  0.9426, -0.4308,  0.8148,  0.2979
 0.7811,  0.8450, -0.6877,  0.7594,  0.2640,  0.2362
-0.6802, -0.1113, -0.8325, -0.6694, -0.6056,  0.6544
 0.3821,  0.1476,  0.7466, -0.5107,  0.2592,  0.1648
 0.7265,  0.9683, -0.9803, -0.4943, -0.5523,  0.2454
-0.9049, -0.9797, -0.0196, -0.9090, -0.4433,  0.6447
-0.4607,  0.1811, -0.2389,  0.4050, -0.0078,  0.5229
 0.2664, -0.2932, -0.4259, -0.7336,  0.8742,  0.1834
-0.4507,  0.1029, -0.6294, -0.1158, -0.6294,  0.6081
 0.8948, -0.0124,  0.9278,  0.2899, -0.0314,  0.1534
-0.1323, -0.8813, -0.0146, -0.0697,  0.6135,  0.2386