GradTree: Learning Axis-Aligned Decision Trees with Gradient Descent

Abstract

Decision Trees (DTs) are commonly used for many machine learning tasks due to their high degree of interpretability. However, learning a DT from data is a difficult optimization problem, as it is non-convex and non-differentiable. Therefore, common approaches learn DTs using a greedy growth algorithm that minimizes the impurity locally at each internal node. Unfortunately, this greedy procedure can lead to inaccurate trees. In this paper, we present a novel approach for learning hard, axis-aligned DTs with gradient descent. The proposed method uses backpropagation with a straight-through operator on a dense DT representation to jointly optimize all tree parameters. Our approach outperforms existing methods on binary classification benchmarks and achieves competitive results for multi-class tasks. The implementation is available at: https://github.com/s-marton/GradTree
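The abstract's central idea is to keep the forward pass hard while letting gradients flow through a straight-through (ST) operator. This can be illustrated on the smallest possible tree: a single axis-aligned split with two leaves. The sketch below is not the GradTree implementation. It is a hand-derived NumPy version under our own assumptions: a binary cross-entropy loss, a split computed as a rounded logistic function of the feature-minus-threshold difference, and feature selection done with a hardmax over per-feature logits whose backward pass borrows the softmax Jacobian (the paper's exact surrogate may differ). In the spirit of a dense representation, the node stores one threshold per feature, and only the selected one is used. All names (`forward`, `st_gradients`, `sgd_step`, `iota`, `tau`, `leaf`) are hypothetical.

```python
import numpy as np

# Minimal sketch (NOT the GradTree implementation): one hard, axis-aligned
# split ("stump") whose parameters are trained by gradient descent using
# straight-through (ST) estimators. All names here are hypothetical.


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()


def forward(X, iota, tau, leaf):
    """Hard forward pass.

    X: (n, d) inputs; iota: (d,) feature logits; tau: (d,) one threshold per
    feature (dense storage, only the selected entry is used); leaf: (2,)
    logits of the left and right leaf.
    """
    onehot = np.eye(len(iota))[np.argmax(iota)]   # hardmax: pick one feature
    soft = sigmoid(X @ onehot - tau @ onehot)     # soft routing probability
    s = np.round(soft)                            # hard routing: 1 = right leaf
    logits = s * leaf[1] + (1.0 - s) * leaf[0]
    return onehot, soft, s, logits


def st_gradients(X, y, iota, tau, leaf):
    """Gradients of the mean binary cross-entropy w.r.t. (iota, tau, leaf).

    The forward pass stays hard. In the backward pass, round() is treated as
    the identity and hardmax is replaced by the softmax Jacobian (ST trick).
    """
    onehot, soft, s, logits = forward(X, iota, tau, leaf)
    g = (sigmoid(logits) - y) / len(y)            # dLoss/dlogit per sample
    g_leaf = np.array([np.sum(g * (1.0 - s)), np.sum(g * s)])
    g_s = g * (leaf[1] - leaf[0])                 # dLoss/ds
    g_z = g_s * soft * (1.0 - soft)               # ST: d round(u)/du := 1
    g_tau = -onehot * g_z.sum()                   # only the selected threshold
    g_onehot = (X - tau).T @ g_z                  # dLoss/d(onehot), shape (d,)
    p = softmax(iota)
    g_iota = p * (g_onehot - p @ g_onehot)        # ST: backprop via softmax
    return g_iota, g_tau, g_leaf


def sgd_step(X, y, iota, tau, leaf, lr=0.5):
    """One plain gradient-descent update of all split and leaf parameters."""
    g_iota, g_tau, g_leaf = st_gradients(X, y, iota, tau, leaf)
    return iota - lr * g_iota, tau - lr * g_tau, leaf - lr * g_leaf
```

For example, with `iota = [5, 0]`, `tau = [0, 0]`, `leaf = [-2, 2]` and the default `lr=0.5`, a single sample `(0.2, 0.0)` with label 0 is first routed to the right leaf; the ST gradient on the selected threshold is negative, so one `sgd_step` raises that threshold past 0.2 and the sample is routed left. Because the forward pass is never relaxed, the model after every step is an exact hard, axis-aligned split. The price is that the ST gradients only approximate the true gradient, which with respect to the thresholds and feature logits is zero almost everywhere.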

Cite

Text

Marton et al. "GradTree: Learning Axis-Aligned Decision Trees with Gradient Descent." AAAI Conference on Artificial Intelligence, 2024. doi:10.1609/AAAI.V38I13.29345

Markdown

[Marton et al. "GradTree: Learning Axis-Aligned Decision Trees with Gradient Descent." AAAI Conference on Artificial Intelligence, 2024.](https://mlanthology.org/aaai/2024/marton2024aaai-gradtree/) doi:10.1609/AAAI.V38I13.29345

BibTeX

@inproceedings{marton2024aaai-gradtree,
  title     = {{GradTree: Learning Axis-Aligned Decision Trees with Gradient Descent}},
  author    = {Marton, Sascha and Lüdtke, Stefan and Bartelt, Christian and Stuckenschmidt, Heiner},
  booktitle = {AAAI Conference on Artificial Intelligence},
  year      = {2024},
  pages     = {14323--14331},
  doi       = {10.1609/AAAI.V38I13.29345},
  url       = {https://mlanthology.org/aaai/2024/marton2024aaai-gradtree/}
}