Blend trees and counters of two or more trained CatBoost models into a new model. Leaf values can be individually weighted for each input model. For example, it may be useful to blend models trained on different validation datasets.
Method call format
sum_models(models, weights=None, ctr_merge_policy='IntersectingCountersAverage')
Parameters
models
A list of models to blend.
Possible types: list of CatBoost models
weights
A list of weights for the leaf values of each model. The length of this list must be equal to the number of blended models.
A list of weights equal to 1.0/N for N blended models gives the average prediction. For example, the following list of weights gives the average prediction for four blended models: [0.25, 0.25, 0.25, 0.25]
Possible types: list of numbers
Default value: None (the leaf value weights are set to 1 for all models)
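Why 1.0/N weights give the average can be checked with a toy NumPy sketch. It makes a simplifying assumption: each model's raw prediction for an object is just the sum of the leaf values that object reaches (scale 1, bias 0). The arrays are made-up numbers, not real CatBoost output. Because the blended model sums weighted leaf values, its raw prediction is the weighted sum of the individual raw predictions, so equal weights of 1/N yield their mean. For a classifier, this is an average on the raw (log-odds) scale, which generally differs from averaging the predicted probabilities.

```python
import numpy as np

# Hypothetical leaf values reached by 3 objects in each of 4 models
# (shape: models x objects x trees); made-up numbers for illustration only.
rng = np.random.default_rng(0)
leaf_values = rng.normal(size=(4, 3, 10))

# Assumed raw prediction of each model: sum of its leaf values (scale 1, bias 0)
raw_per_model = leaf_values.sum(axis=2)  # shape (4, 3)

weights = np.full(4, 1.0 / 4)  # [0.25, 0.25, 0.25, 0.25]

# Blending: multiply every model's leaf values by its weight, then sum all trees
blended_raw = (weights[:, None, None] * leaf_values).sum(axis=(0, 2))  # shape (3,)


def sigmoid(z):
    """Map raw (log-odds) values to probabilities."""
    return 1.0 / (1.0 + np.exp(-z))


# The blended raw prediction equals the mean of the raw predictions...
print(np.allclose(blended_raw, raw_per_model.mean(axis=0)))
# ...but its probability is, in general, not the mean of the probabilities
print(sigmoid(blended_raw))
print(sigmoid(raw_per_model).mean(axis=0))
```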
ctr_merge_policy
The counter merging policy. Possible values:
- FailIfCtrsIntersects: Ensure that the models have no intersecting counters.
- LeaveMostDiversifiedTable: Use the most diversified counters, measured by the number of unique hash values.
- IntersectingCountersAverage: Use the average counter (CTR) values in the intersecting bins.
Possible types: string
Default value: IntersectingCountersAverage
Note
- The bias of the resulting model is equal to the weighted sum of the input models' biases.
- The scale of the resulting model is equal to 1; the leaf values are scaled before summation.
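The bias and scale rules can be illustrated with another toy NumPy sketch. It relies on two simplifying assumptions about the internals. First, a single model's raw prediction is scale * (sum of leaf values) + bias. Second, "leaf values are scaled before summation" means each model's leaves are multiplied by its own scale and its weight. The numbers are made up. Under these assumptions, a blended model with scale 1 and bias equal to the weighted sum of the biases reproduces the weighted sum of the individual raw predictions.

```python
import numpy as np

rng = np.random.default_rng(1)

n_models, n_objects, n_trees = 3, 5, 8
# Hypothetical per-model parameters (made up for illustration only)
leaf_values = rng.normal(size=(n_models, n_objects, n_trees))
scales = np.array([1.0, 0.5, 2.0])
biases = np.array([0.1, -0.3, 0.2])
weights = np.array([0.5, 0.3, 0.2])

# Assumed raw prediction of each model: scale * sum(leaves) + bias
raw_per_model = scales[:, None] * leaf_values.sum(axis=2) + biases[:, None]

# Blended model: leaves pre-multiplied by weight * scale, scale fixed to 1,
# bias = weighted sum of the input biases
blended_leaves = (weights * scales)[:, None, None] * leaf_values
blended_scale = 1.0
blended_bias = np.dot(weights, biases)
blended_raw = blended_scale * blended_leaves.sum(axis=(0, 2)) + blended_bias

# Reference: weight the individual raw predictions directly
expected = (weights[:, None] * raw_per_model).sum(axis=0)
print(np.allclose(blended_raw, expected))
```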
Type of return value
CatBoost
Example
```python
from catboost import CatBoostClassifier, Pool, sum_models
from catboost.datasets import amazon
import numpy as np
from sklearn.model_selection import train_test_split

train_df, _ = amazon()

y = train_df.ACTION
X = train_df.drop('ACTION', axis=1)

# Treat every non-float column as categorical.
# np.float was removed from NumPy, so compare against np.float64;
# np.where returns a tuple, so take its first element.
categorical_features_indices = np.where(X.dtypes != np.float64)[0]

X_train, X_validation, y_train, y_validation = train_test_split(
    X, y, train_size=0.8, random_state=42)

train_pool = Pool(X_train, y_train, cat_features=categorical_features_indices)
validate_pool = Pool(X_validation, y_validation, cat_features=categorical_features_indices)

# Train five models that differ only in their random seed
models = []
for i in range(5):
    model = CatBoostClassifier(iterations=100, random_seed=i)
    model.fit(train_pool, eval_set=validate_pool)
    models.append(model)

# Equal weights of 1/N give the average prediction of the models
models_avrg = sum_models(models, weights=[1.0 / len(models)] * len(models))
```