GradTree: Learning Axis-Aligned Decision Trees with Gradient Descent
Abstract
Decision Trees (DTs) are commonly used for many machine learning tasks due to their high degree of interpretability. However, learning a DT from data is a difficult optimization problem, as it is non-convex and non-differentiable. Therefore, common approaches learn DTs using a greedy growth algorithm that minimizes the impurity locally at each internal node. Unfortunately, this greedy procedure can lead to inaccurate trees. In this paper, we present a novel approach for learning hard, axis-aligned DTs with gradient descent. The proposed method uses backpropagation with a straight-through operator on a dense DT representation to jointly optimize all tree parameters. Our approach outperforms existing methods on a wide range of binary classification benchmarks and is available at: https://github.com/s-marton/GradTree
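The abstract does not spell out how gradients reach a hard split, so the following is a minimal, hypothetical sketch (not the authors' GradTree code) of the straight-through idea on the smallest possible tree: a single axis-aligned split (a stump). It assumes one threshold per feature, a softmax over illustrative `feature_logits` for feature selection (the paper may use a different normalization), and the logistic function as the backward surrogate for the hard step. The forward pass stays hard, so nudging a threshold by a small amount leaves the loss unchanged; the surrogate gradient is what lets all parameters, including the discrete feature choice, be updated jointly. All names (`forward`, `st_gradients`, `sgd_step`, the parameter keys) are made up for illustration.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def forward(params, x):
    """Hard forward pass of a depth-1, axis-aligned tree (a decision stump)."""
    probs = softmax(params["feature_logits"])
    chosen = max(range(len(probs)), key=lambda j: probs[j])  # hard one-hot feature choice
    z = x[chosen] - params["thresholds"][chosen]
    s = 1.0 if z > 0 else 0.0  # hard split: 1 = right leaf, 0 = left leaf
    logit = (1.0 - s) * params["leaf_left"] + s * params["leaf_right"]
    return probs, chosen, z, s, sigmoid(logit)


def loss(params, X, y):
    """Mean binary cross-entropy of the hard tree."""
    total = 0.0
    for x, label in zip(X, y):
        p = forward(params, x)[-1]
        total -= label * math.log(p) + (1 - label) * math.log(1 - p)
    return total / len(X)


def st_gradients(params, X, y):
    """Gradients of `loss` using straight-through estimators for both hard ops."""
    d = len(params["feature_logits"])
    g = {"feature_logits": [0.0] * d, "thresholds": [0.0] * d,
         "leaf_left": 0.0, "leaf_right": 0.0}
    for x, label in zip(X, y):
        probs, chosen, z, s, p = forward(params, x)
        d_logit = p - label  # dBCE/dlogit
        g["leaf_left"] += d_logit * (1.0 - s)
        g["leaf_right"] += d_logit * s
        d_s = d_logit * (params["leaf_right"] - params["leaf_left"])
        # Straight-through: backpropagate the hard step as if it were sigmoid(z).
        d_z = d_s * sigmoid(z) * (1.0 - sigmoid(z))
        g["thresholds"][chosen] -= d_z  # only the selected feature's threshold is used
        # z = sum_j h_j * (x_j - thr_j) with one-hot h, so dz/dh_j = x_j - thr_j.
        g_h = [d_z * (x[j] - params["thresholds"][j]) for j in range(d)]
        # Straight-through: backpropagate the one-hot h as if it were softmax(logits).
        mean_g = sum(pj * gj for pj, gj in zip(probs, g_h))
        for k in range(d):
            g["feature_logits"][k] += probs[k] * (g_h[k] - mean_g)
    n = len(X)
    return {k: ([v / n for v in val] if isinstance(val, list) else val / n)
            for k, val in g.items()}


def sgd_step(params, X, y, lr=0.1):
    """Update all tree parameters jointly with one gradient-descent step."""
    g = st_gradients(params, X, y)
    return {k: ([v - lr * gv for v, gv in zip(val, g[k])] if isinstance(val, list)
                else val - lr * g[k])
            for k, val in params.items()}


# Usage: one sample that the initial split (on feature 0) routes to the wrong leaf.
params = {"feature_logits": [0.0, 0.0], "thresholds": [0.5, 0.5],
          "leaf_left": -1.0, "leaf_right": 1.0}
X, y = [[0.6, 0.2]], [0]
before = loss(params, X, y)
params = sgd_step(params, X, y, lr=0.1)
after = loss(params, X, y)
print(forward(params, X[0])[1], round(before, 3), round(after, 3))  # prints: 1 1.313 0.313
```

In this toy case a single step flips the hard feature choice from feature 0 to feature 1, which a finite-difference gradient could never do because the loss is piecewise constant in the split parameters. Note that the threshold of a non-selected feature receives no gradient in this sketch; how GradTree itself handles this, and how it scales the construction to deeper dense trees, is described in the paper and repository rather than here.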
Cite
Marton et al. "GradTree: Learning Axis-Aligned Decision Trees with Gradient Descent." NeurIPS 2023 Workshops: TRL, 2023.
[Marton et al. "GradTree: Learning Axis-Aligned Decision Trees with Gradient Descent." NeurIPS 2023 Workshops: TRL, 2023.](https://mlanthology.org/neuripsw/2023/marton2023neuripsw-gradtree/)
@inproceedings{marton2023neuripsw-gradtree,
title = {{GradTree: Learning Axis-Aligned Decision Trees with Gradient Descent}},
author = {Marton, Sascha and Lüdtke, Stefan and Bartelt, Christian and Stuckenschmidt, Heiner},
booktitle = {NeurIPS 2023 Workshops: TRL},
year = {2023},
url = {https://mlanthology.org/neuripsw/2023/marton2023neuripsw-gradtree/}
}