PMML
The Predictive Model Markup Language (PMML) is an XML-based language that provides a way for applications to define statistical and data mining models and to share them between PMML-compliant applications.
Specifics
- CatBoost exports models to PMML version 4.3.
- If the training dataset contains categorical features, they must be one-hot encoded during training. To ensure this, set the --one-hot-max-size/one_hot_max_size parameter to a value greater than the maximum number of unique values of any categorical feature in the dataset.
Multiclass classification models are not currently supported.
Models saved as PMML cannot currently be loaded by CatBoost libraries/executable. Use this format if the model is intended to be used with external machine learning libraries.
The native model format (cbm) used by CatBoost libraries/executable is usually faster to apply on common x86-64 platforms because it is optimized for CatBoost-specific oblivious tree structures.
Model parameters
Inputs
<DataField name="<feature_name>" optype="continuous" dataType="float"/>
<DataField name="<feature_name>" optype="categorical" dataType="string"/>
Outputs
- Classification
- Only binary classification is currently supported.
<DataField name="prediction" optype="categorical" dataType="boolean"/>
- Regression
- The prediction is a continuous value of type double.
<DataField name="prediction" optype="continuous" dataType="double"/>
Examples
The following examples use the CatBoost Python package for training and the JPMML-Evaluator Java API for applying the model.
Binary classification
Training:
import catboost
from sklearn import datasets

# Wisconsin breast cancer dataset bundled with scikit-learn (binary target).
breast_cancer = datasets.load_breast_cancer()

# Attach the dataset's feature names so the exported model refers to its
# inputs by these names rather than by column index.
train_dataset = catboost.Pool(
    breast_cancer.data,
    label=breast_cancer.target,
    feature_names=list(breast_cancer.feature_names),
)

classifier_params = dict(loss_function='Logloss')
model = catboost.CatBoostClassifier(**classifier_params)
model.fit(train_dataset)

# Save model to PMML format; the export parameters fill in PMML header metadata.
pmml_metadata = dict(
    pmml_copyright='my copyright (c)',
    pmml_description='test model for BinaryClassification',
    pmml_model_version='1',
)
model.save_model("breast_cancer.pmml", format="pmml", export_parameters=pmml_metadata)
package com.mycompany.app;
import java.io.*;
import java.util.*;
import org.xml.sax.SAXException;
import org.dmg.pmml.*;
import org.jpmml.model.*;
import org.jpmml.evaluator.*;
public class App
{
public static void main(String[] args) throws Exception
{
String modelPath = "breast_cancer.pmml";
Evaluator evaluator = new LoadingModelEvaluatorBuilder()
.setLocatable(false)
.setVisitors(new DefaultVisitorBattery())
//.setOutputFilter(OutputFilters.KEEP_FINAL_RESULTS)
.load(new File(modelPath))
.build();
Map<String, Float> inputDataRecord = new HashMap<String,Float>();
inputDataRecord.put("mean radius", 17.99f);
inputDataRecord.put("mean texture", 10.38f);
inputDataRecord.put("mean perimeter", 122.8f);
inputDataRecord.put("mean area", 1001.0f);
inputDataRecord.put("mean smoothness", 0.1184f);
inputDataRecord.put("mean compactness", 0.2776f);
inputDataRecord.put("mean concavity", 0.3001f);
inputDataRecord.put("mean concave points", 0.1471f);
inputDataRecord.put("mean symmetry", 0.2419f);
inputDataRecord.put("mean fractal dimension", 0.07871f);
inputDataRecord.put("radius error", 1.095f);
inputDataRecord.put("texture error", 0.9053f);
inputDataRecord.put("perimeter error", 8.589f);
inputDataRecord.put("area error", 153.4f);
inputDataRecord.put("smoothness error", 0.006399f);
inputDataRecord.put("compactness error", 0.04904f);
inputDataRecord.put("concavity error", 0.05373f);
inputDataRecord.put("concave points error", 0.01587f);
inputDataRecord.put("symmetry error", 0.03003f);
inputDataRecord.put("fractal dimension error", 0.006193f);
inputDataRecord.put("worst radius", 25.38f);
inputDataRecord.put("worst texture", 17.33f);
inputDataRecord.put("worst perimeter", 184.6f);
inputDataRecord.put("worst area", 2019.0f);
inputDataRecord.put("worst smoothness", 0.1622f);
inputDataRecord.put("worst compactness", 0.6656f);
inputDataRecord.put("worst concavity", 0.7119f);
inputDataRecord.put("worst concave points", 0.2654f);
inputDataRecord.put("worst symmetry", 0.4601f);
inputDataRecord.put("worst fractal dimension", 0.1189f);
Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
List<? extends InputField> inputFields = evaluator.getInputFields();
for(InputField inputField : inputFields){
FieldName inputName = inputField.getName();
Object rawValue = inputDataRecord.get(inputName.getValue());
// Transforming an arbitrary user-supplied value to a known-good PMML value
// The user-supplied value is passed through: 1) outlier treatment, 2) missing value treatment, 3) invalid value treatment and 4) type conversion
FieldValue inputValue = inputField.prepare(rawValue);
arguments.put(inputName, inputValue);
}
Map<FieldName, ?> results = evaluator.evaluate(arguments);
List<? extends TargetField> targetFields = evaluator.getTargetFields();
for(TargetField targetField : targetFields){
FieldName targetName = targetField.getName();
Object targetValue = results.get(targetName);
System.out.println(targetName);
System.out.println(targetValue);
}
}
}
Regression
Training:
import catboost
import numpy as np
import pandas as pd

# sklearn.datasets.load_boston() was removed in scikit-learn 1.2, so the Boston
# housing data is read from its original source instead, following the
# replacement recipe given in scikit-learn's removal notice.
data_url = "http://lib.stat.cmu.edu/datasets/boston"
raw_df = pd.read_csv(data_url, sep=r"\s+", skiprows=22, header=None)
# Each record spans two lines of the raw file: 11 features on the first line,
# then the remaining 2 features followed by the target value on the second.
data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
target = raw_df.values[1::2, 2]
# Feature names in raw-file column order; they become the model's input field names.
feature_names = [
    'CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE',
    'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT',
]

model = catboost.CatBoostRegressor()
train_dataset = catboost.Pool(
    data,
    label=target,
    feature_names=feature_names
)
model.fit(train_dataset)
# Save model to PMML format
model.save_model(
    "boston.pmml",
    format="pmml",
    export_parameters={
        'pmml_copyright': 'my copyright (c)',
        'pmml_description': 'test model for Regression',
        'pmml_model_version': '1'
    }
)
Applying:
package com.mycompany.app;
import java.io.*;
import java.util.*;
import org.xml.sax.SAXException;
import org.dmg.pmml.*;
import org.jpmml.model.*;
import org.jpmml.evaluator.*;
public class App
{
public static void main(String[] args) throws Exception
{
String modelPath = "boston.pmml";
Evaluator evaluator = new LoadingModelEvaluatorBuilder()
.setLocatable(false)
.setVisitors(new DefaultVisitorBattery())
//.setOutputFilter(OutputFilters.KEEP_FINAL_RESULTS)
.load(new File(modelPath))
.build();
Map<String, Float> inputDataRecord = new HashMap<String,Float>();
inputDataRecord.put("CRIM", 0.00632f);
inputDataRecord.put("ZN", 18.0f);
inputDataRecord.put("INDUS", 2.31f);
inputDataRecord.put("CHAS", 0.0f);
inputDataRecord.put("NOX", 0.538f);
inputDataRecord.put("RM", 6.575f);
inputDataRecord.put("AGE", 65.2f);
inputDataRecord.put("DIS", 4.09f);
inputDataRecord.put("RAD", 1.0f);
inputDataRecord.put("TAX", 296.0f);
inputDataRecord.put("PTRATIO", 15.3f);
inputDataRecord.put("B", 396.9f);
inputDataRecord.put("LSTAT", 4.98f);
Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
List<? extends InputField> inputFields = evaluator.getInputFields();
for(InputField inputField : inputFields){
FieldName inputName = inputField.getName();
Object rawValue = inputDataRecord.get(inputName.getValue());
// Transforming an arbitrary user-supplied value to a known-good PMML value
// The user-supplied value is passed through: 1) outlier treatment, 2) missing value treatment, 3) invalid value treatment and 4) type conversion
FieldValue inputValue = inputField.prepare(rawValue);
arguments.put(inputName, inputValue);
}
Map<FieldName, ?> results = evaluator.evaluate(arguments);
List<? extends TargetField> targetFields = evaluator.getTargetFields();
for(TargetField targetField : targetFields){
FieldName targetName = targetField.getName();
Object targetValue = results.get(targetName);
System.out.println(targetName);
System.out.println(targetValue);
}
}
}