PMML
The Predictive Model Markup Language (PMML) is an XML-based language that provides a way for applications to define statistical and data mining models and to share those models between PMML-compliant applications.
Specifics
-
CatBoost exports models to PMML version 4.3.
-
If the training dataset contains categorical features, they must be interpreted as one-hot encoded during training. This can be accomplished by setting the --one-hot-max-size / one_hot_max_size parameter to a value that is greater than the maximum number of unique values among all categorical features in the dataset.
-
Multiclassification models are not currently supported.
-
Models saved as PMML cannot currently be loaded by the CatBoost libraries or executable. Use this format if the model is intended to be used with external machine learning libraries.
-
The native model format (cbm) used by the CatBoost libraries and executable is usually faster to apply on common x86-64 platforms because it is optimized for the CatBoost-specific oblivious tree structure.
Model parameters
Inputs
Numerical features description format:
<DataField name="<feature_name>" optype="continuous" dataType="float"/>
Categorical features description format:
<DataField name="<feature_name>" optype="categorical" dataType="string"/>
Outputs
Classification
Only binary classification is currently supported.
<DataField name="prediction" optype="categorical" dataType="boolean"/>
Regression
<DataField name="prediction" optype="continuous" dataType="double"/>
Examples
The following examples use the Python package for training and the Java Evaluator API for PMML for applying the model.
Binary classification
Training:
import catboost
from sklearn import datasets

# Load the breast cancer dataset that ships with scikit-learn
breast_cancer = datasets.load_breast_cancer()

# Wrap features and labels in a Pool; the dataset's column names are passed
# along as the feature names
train_pool = catboost.Pool(
    data=breast_cancer.data,
    label=breast_cancer.target,
    feature_names=list(breast_cancer.feature_names),
)

# Binary classifier trained with the Logloss objective
classifier = catboost.CatBoostClassifier(loss_function='Logloss')
classifier.fit(train_pool)

# Metadata passed to the PMML exporter
pmml_export_parameters = {
    'pmml_copyright': 'my copyright (c)',
    'pmml_description': 'test model for BinaryClassification',
    'pmml_model_version': '1',
}

# Save model to PMML format
classifier.save_model(
    "breast_cancer.pmml",
    format="pmml",
    export_parameters=pmml_export_parameters,
)
Applying:
package com.mycompany.app;
import java.io.*;
import java.util.*;
import org.xml.sax.SAXException;
import org.dmg.pmml.*;
import org.jpmml.model.*;
import org.jpmml.evaluator.*;
/**
 * Example application that scores one hard-coded record with a binary
 * classification model stored in PMML format.
 */
public class App
{
    /**
     * Loads {@code breast_cancer.pmml}, evaluates it on a single sample and
     * prints each target field name followed by its value.
     *
     * @param args command-line arguments (unused)
     * @throws Exception if loading or evaluating the model fails
     */
    public static void main(String[] args) throws Exception
    {
        Evaluator evaluator = loadEvaluator("breast_cancer.pmml");
        Map<String, Float> sample = buildSample();

        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        for (InputField field : evaluator.getInputFields()) {
            FieldName fieldName = field.getName();
            Object userValue = sample.get(fieldName.getValue());
            // prepare() converts an arbitrary user-supplied value into a known-good PMML value:
            // 1) outlier treatment, 2) missing value treatment,
            // 3) invalid value treatment and 4) type conversion
            arguments.put(fieldName, field.prepare(userValue));
        }

        Map<FieldName, ?> results = evaluator.evaluate(arguments);

        for (TargetField target : evaluator.getTargetFields()) {
            FieldName targetName = target.getName();
            Object prediction = results.get(targetName);
            System.out.println(targetName);
            System.out.println(prediction);
        }
    }

    /** Builds an evaluator for the PMML file at the given path. */
    private static Evaluator loadEvaluator(String modelPath) throws Exception
    {
        return new LoadingModelEvaluatorBuilder()
            .setLocatable(false)
            .setVisitors(new DefaultVisitorBattery())
            // Optional output filter, disabled in this example:
            //.setOutputFilter(OutputFilters.KEEP_FINAL_RESULTS)
            .load(new File(modelPath))
            .build();
    }

    /** Returns one sample record keyed by feature name. */
    private static Map<String, Float> buildSample()
    {
        Map<String, Float> record = new HashMap<>();

        // "mean" features
        record.put("mean radius", 17.99f);
        record.put("mean texture", 10.38f);
        record.put("mean perimeter", 122.8f);
        record.put("mean area", 1001.0f);
        record.put("mean smoothness", 0.1184f);
        record.put("mean compactness", 0.2776f);
        record.put("mean concavity", 0.3001f);
        record.put("mean concave points", 0.1471f);
        record.put("mean symmetry", 0.2419f);
        record.put("mean fractal dimension", 0.07871f);

        // "error" features
        record.put("radius error", 1.095f);
        record.put("texture error", 0.9053f);
        record.put("perimeter error", 8.589f);
        record.put("area error", 153.4f);
        record.put("smoothness error", 0.006399f);
        record.put("compactness error", 0.04904f);
        record.put("concavity error", 0.05373f);
        record.put("concave points error", 0.01587f);
        record.put("symmetry error", 0.03003f);
        record.put("fractal dimension error", 0.006193f);

        // "worst" features
        record.put("worst radius", 25.38f);
        record.put("worst texture", 17.33f);
        record.put("worst perimeter", 184.6f);
        record.put("worst area", 2019.0f);
        record.put("worst smoothness", 0.1622f);
        record.put("worst compactness", 0.6656f);
        record.put("worst concavity", 0.7119f);
        record.put("worst concave points", 0.2654f);
        record.put("worst symmetry", 0.4601f);
        record.put("worst fractal dimension", 0.1189f);

        return record;
    }
}
Regression
Training:
import catboost
import numpy as np
import pandas as pd

# datasets.load_boston() was deprecated in scikit-learn 1.0 and removed in 1.2,
# so the data is read from its original source instead (this is the replacement
# suggested by the scikit-learn deprecation notice). Requires network access.
boston_url = "http://lib.stat.cmu.edu/datasets/boston"

# Feature names in dataset column order
feature_names = [
    "CRIM", "ZN", "INDUS", "CHAS", "NOX", "RM", "AGE",
    "DIS", "RAD", "TAX", "PTRATIO", "B", "LSTAT",
]

# The file starts with a 22-line text header. Each record then spans two
# physical lines: 11 values on the first and 3 on the second, the last of
# which is the target.
raw = pd.read_csv(boston_url, sep=r"\s+", skiprows=22, header=None)
features = np.hstack([raw.values[::2, :], raw.values[1::2, :2]])
target = raw.values[1::2, 2]

model = catboost.CatBoostRegressor()
train_dataset = catboost.Pool(
    features,
    label=target,
    feature_names=feature_names
)
model.fit(train_dataset)

# Save model to PMML format
model.save_model(
    "boston.pmml",
    format="pmml",
    export_parameters={
        'pmml_copyright': 'my copyright (c)',
        'pmml_description': 'test model for Regression',
        'pmml_model_version': '1'
    }
)
Applying:
package com.mycompany.app;
import java.io.*;
import java.util.*;
import org.xml.sax.SAXException;
import org.dmg.pmml.*;
import org.jpmml.model.*;
import org.jpmml.evaluator.*;
/**
 * Example application that scores one hard-coded record with a regression
 * model stored in PMML format.
 */
public class App
{
    /**
     * Loads {@code boston.pmml}, evaluates it on a single sample and prints
     * each target field name followed by its value.
     *
     * @param args command-line arguments (unused)
     * @throws Exception if loading or evaluating the model fails
     */
    public static void main(String[] args) throws Exception
    {
        Evaluator evaluator = loadEvaluator("boston.pmml");
        Map<String, Float> sample = buildSample();

        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        for (InputField field : evaluator.getInputFields()) {
            FieldName fieldName = field.getName();
            Object userValue = sample.get(fieldName.getValue());
            // prepare() converts an arbitrary user-supplied value into a known-good PMML value:
            // 1) outlier treatment, 2) missing value treatment,
            // 3) invalid value treatment and 4) type conversion
            arguments.put(fieldName, field.prepare(userValue));
        }

        Map<FieldName, ?> results = evaluator.evaluate(arguments);

        for (TargetField target : evaluator.getTargetFields()) {
            FieldName targetName = target.getName();
            Object prediction = results.get(targetName);
            System.out.println(targetName);
            System.out.println(prediction);
        }
    }

    /** Builds an evaluator for the PMML file at the given path. */
    private static Evaluator loadEvaluator(String modelPath) throws Exception
    {
        return new LoadingModelEvaluatorBuilder()
            .setLocatable(false)
            .setVisitors(new DefaultVisitorBattery())
            // Optional output filter, disabled in this example:
            //.setOutputFilter(OutputFilters.KEEP_FINAL_RESULTS)
            .load(new File(modelPath))
            .build();
    }

    /** Returns one sample record keyed by feature name. */
    private static Map<String, Float> buildSample()
    {
        Map<String, Float> record = new HashMap<>();
        record.put("CRIM", 0.00632f);
        record.put("ZN", 18.0f);
        record.put("INDUS", 2.31f);
        record.put("CHAS", 0.0f);
        record.put("NOX", 0.538f);
        record.put("RM", 6.575f);
        record.put("AGE", 65.2f);
        record.put("DIS", 4.09f);
        record.put("RAD", 1.0f);
        record.put("TAX", 296.0f);
        record.put("PTRATIO", 15.3f);
        record.put("B", 396.9f);
        record.put("LSTAT", 4.98f);
        return record;
    }
}