11

If m is the number of features and n is the number of samples, the scikit-learn documentation (http://scikit-learn.org/stable/modules/tree.html) states that the runtime to construct a binary decision tree is O(mn log n).

I understand that the log(n) comes from the average height of the tree after splitting. I understand that at each split you have to look at each of the m features and choose the best one to split on. I understand that this is done by calculating a "best metric" (in my case, a Gini impurity) for each of the n samples at that node. However, to find the best split, doesn't this mean that you would have to look at each possible way to split the samples for each feature? And wouldn't that be something like 2^n-1 * m rather than just mn? Am I thinking about this wrong? Any advice would help. Thank you.

templatetypedef
iltp38
  • Could it be because we use a greedy approach to get a (good tree, low time) trade-off, and as a result do not get the best possible tree (i.e., maximally compact)? Getting the best possible tree is supposed to be NP-hard, which I believe would have the complexity you mentioned. – rahs Mar 11 '21 at 17:57

1 Answer

16

One way to build a decision tree would be, at each point, to do something like this:

  • For each possible feature to split on:
    • Find the best possible split for that feature.
    • Determine the "goodness" of this fit.
  • Of all the options tried above, take the best and use that for the split.

The question is how to perform each step. If you have continuous data, a common technique for finding the best possible split for a feature is to sort the samples into ascending order by that feature, then consider every candidate partition point between consecutive values and take the one that minimizes the entropy (or Gini impurity). The sorting step takes O(n log n) time, and it dominates: by keeping running class counts as you sweep across the sorted samples, each of the at most n − 1 candidate thresholds can be evaluated in O(1), so the scan itself is only O(n). Since we're doing that for each of the O(m) features, the runtime ends up working out to O(mn log n) total work done per node.
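To make the per-feature cost concrete, here is a minimal sketch of that sort-and-scan idea for one continuous feature. It is not scikit-learn's actual implementation; it assumes binary 0/1 labels and Gini impurity, and the function name `best_split_for_feature` is made up for illustration:

```python
import numpy as np

def best_split_for_feature(x, y):
    """Find the best threshold split of one continuous feature.

    x : 1-D array of feature values for the samples at this node.
    y : 1-D array of 0/1 class labels (binary classification assumed).

    The sort costs O(n log n) and dominates; the sweep over the at most
    n - 1 candidate thresholds keeps running class counts, so evaluating
    each candidate's Gini impurity is O(1), not O(n).
    """
    order = np.argsort(x)                      # O(n log n) -- the dominant cost
    x_sorted, y_sorted = x[order], y[order]
    n = len(y_sorted)

    def gini(pos, count):                      # Gini impurity of a two-class set
        p = pos / count
        return 2.0 * p * (1.0 - p)

    total_pos = int(y_sorted.sum())
    left_pos = 0                               # running count of positives on the left side
    best_gini, best_threshold = float("inf"), None

    for i in range(1, n):                      # candidate split between samples i-1 and i
        left_pos += int(y_sorted[i - 1])
        if x_sorted[i] == x_sorted[i - 1]:
            continue                           # equal values: no threshold separates them
        n_left, n_right = i, n - i
        weighted = (n_left * gini(left_pos, n_left)
                    + n_right * gini(total_pos - left_pos, n_right)) / n
        if weighted < best_gini:
            best_gini = weighted
            best_threshold = (x_sorted[i - 1] + x_sorted[i]) / 2.0

    return best_threshold, best_gini
```

Looping this over the m features at a node and keeping the best result is what gives the O(mn log n) cost per node.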

templatetypedef
  • Even if it is sorted, wouldn't finding the best possible split still take 2^n time for each feature? Since you would have to check each possible way to split the data? This grows faster than n log n, so I thought that would dominate the runtime. – iltp38 Dec 10 '15 at 22:53
  • @iltp38 While you're right that there are 2^n different partitions of the data into two sets, remember that decision trees are built by constructing some simple rule you can use to determine which subtree to descend into. In the context of decision trees like the ones you're describing, this is usually done by picking some simple splitting criterion like "pick some individual feature, pick a threshold, and split the points into 'ones below the threshold' and 'ones above the threshold.'" This reduces the number of possible splits dramatically. (continued...) – templatetypedef Dec 10 '15 at 23:40
  • @iltp38 It also ensures that the tree is usable. After all, when you get a new test point, you need to know how you're going to determine which direction to go at each point, and if you picked an arbitrary clustering at the node you won't necessarily know which partition to descend into. – templatetypedef Dec 10 '15 at 23:41
  • @templatetypedef Why is a single sort enough? Once you split the data, wouldn't you have to sort again? The order according to one feature might be different from the order you get by sorting on another feature, so you would have to sort again at each step. Here, https://sebastianraschka.com/pdf/lecture-notes/stat479fs18/06_trees_notes.pdf, the quoted complexity is actually O(mn^2 log n). – them Mar 13 '20 at 05:03
  • @them Read a little farther down in the text you linked. He states that it's m log(n) with "caching tricks". – tkunk May 11 '22 at 19:11
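For what it's worth, one common form of such a caching trick (a sketch under my own assumptions, not necessarily what Raschka or scikit-learn actually does) is to argsort every feature once up front and then, at each split, stably partition the presorted index lists into the two children. That keeps each child's indices sorted by every feature at O(n) cost per feature, instead of paying a fresh O(n log n) sort at every node:

```python
import numpy as np

def split_sorted_indices(sorted_idx_per_feature, goes_left):
    """Partition each feature's presorted index list into the two children.

    sorted_idx_per_feature : list of 1-D index arrays, one per feature,
        each already sorted by that feature's values (computed once, up front).
    goes_left : boolean mask over all samples saying which child each goes to.

    A stable filter preserves the sorted order within each child, so the
    children never need to be re-sorted.
    """
    left_lists, right_lists = [], []
    for idx in sorted_idx_per_feature:
        mask = goes_left[idx]
        left_lists.append(idx[mask])    # order within the child is preserved
        right_lists.append(idx[~mask])
    return left_lists, right_lists
```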