plot_tree
Visualize a single decision tree from the trained CatBoost model.
Dependencies
The graphviz Python package must be installed to use this method (the returned graph is a graphviz.dot.Digraph). Rendering the graph to an image additionally requires the Graphviz binaries.
Method call format
plot_tree(tree_idx, pool=None)
Parameters
tree_idx
Description
The index of the tree from the model that should be visualized.
Possible types
int
Default value
Required parameter
pool
Description
An optional parameter for models that contain only float features. Pass a pool to label the features with their external indices from this pool; if no pool is passed, internal indices are used (see the sketch below this parameter description).
For example, for a semicolon-separated pool with two features, f1;label;f2,
the external feature indices are 0 and 2, while the internal indices are 0 and 1, respectively.
Possible types
catboost.Pool
Default value
None
Required for models with one-hot encoded categorical features
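A minimal sketch of the effect of this parameter, using synthetic data and hypothetical feature names f0, f1, f2: a model with only float features is plotted once without a pool (internal indices) and once with a pool (labels taken from that pool).
import numpy as np
from catboost import CatBoostRegressor, Pool

rng = np.random.RandomState(0)
X = rng.rand(100, 3)
y = X[:, 0] + 2 * X[:, 2]

train_pool = Pool(X, y, feature_names=["f0", "f1", "f2"])
model = CatBoostRegressor(iterations=2, max_depth=2, verbose=False).fit(train_pool)

# Without a pool: splits are labelled with internal feature indices.
graph_internal = model.plot_tree(tree_idx=0)

# With a pool: splits are labelled using the features of the passed pool.
graph_external = model.plot_tree(tree_idx=0, pool=train_pool)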
Type of return value
A graphviz.dot.Digraph object describing the visualized tree.
Inner vertices of the tree correspond to splits and specify the factor names and borders used in these splits.
Leaf vertices contain the raw values predicted by the tree (RawFormulaVal, see Model values).
For MultiClass models, leaves contain ClassCount values (with zero sum); the class of a leaf can be obtained as the argmax of this array of values.
For MultiRMSE models, leaves contain one value for each label.
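For instance, a minimal MultiClass sketch with synthetic data and hypothetical feature names: each leaf of the plotted tree holds one value per class, and the argmax of those values is the class the leaf votes for.
import numpy as np
from catboost import CatBoostClassifier, Pool

rng = np.random.RandomState(0)
X = rng.rand(300, 4)
y = rng.randint(0, 3, size=300)  # three classes

train_pool = Pool(X, y, feature_names=["f0", "f1", "f2", "f3"])
model = CatBoostClassifier(
    loss_function="MultiClass", iterations=2, max_depth=2, verbose=False).fit(train_pool)

# Each leaf shows one value per class; np.argmax over these values gives the leaf's class.
graph = model.plot_tree(tree_idx=0, pool=train_pool)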
Usage examples
import numpy as np
from catboost import CatBoostRegressor, Pool
from catboost.datasets import titanic

# Load the Titanic dataset and split it into features and the target.
titanic_df = titanic()
X = titanic_df[0].drop('Survived', axis=1)
y = titanic_df[0].Survived

# Treat every non-float column as categorical and fill its missing values.
is_cat = (X.dtypes != float)
for feature, feat_is_cat in is_cat.to_dict().items():
    if feat_is_cat:
        X[feature] = X[feature].fillna("NAN")

cat_features_index = np.where(is_cat)[0]
pool = Pool(X, y, cat_features=cat_features_index, feature_names=list(X.columns))

# Train a small model so that the individual trees are easy to inspect.
model = CatBoostRegressor(
    max_depth=2, verbose=False, max_ctr_complexity=1, iterations=2).fit(pool)

# The pool is required here because the model contains one-hot encoded categorical features.
model.plot_tree(
    tree_idx=0,
    pool=pool
)
An example of a plotted tree:
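The returned object is a standard graphviz Digraph, so it can be inspected or saved. A short follow-up to the example above (rendering to an image additionally requires the Graphviz binaries); the output file name is arbitrary:
graph = model.plot_tree(tree_idx=0, pool=pool)

print(graph.source)                           # DOT source of the tree
graph.render("titanic_tree_0", format="png")  # writes titanic_tree_0 and titanic_tree_0.png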
Tutorial
Refer to the Visualization of CatBoost decision trees tutorial for details.