package spark
CatBoost is a machine learning algorithm that uses gradient boosting on decision trees.
Overview
This package provides classes that implement interfaces from the Apache Spark Machine Learning Library (MLlib).
For binary and multiclass classification problems use CatBoostClassifier; for regression use CatBoostRegressor.
These classes implement the usual fit method of org.apache.spark.ml.Predictor that accepts a single org.apache.spark.sql.DataFrame for training, but you can also use another fit method that accepts additional datasets for computing evaluation metrics and overfitting detection, similarly to CatBoost's other APIs.
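For illustration, a minimal sketch of both fit variants (assuming trainPool and evalPool are already initialized Pool instances):
// train on a single dataset
val classifier = new CatBoostClassifier
val model = classifier.fit(trainPool)

// train with an additional evaluation dataset used for metrics and overfitting detection
val modelWithEval = classifier.fit(trainPool, Array[Pool](evalPool))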
This package also contains the Pool class, which is CatBoost's abstraction of a dataset. It carries additional information compared to a plain org.apache.spark.sql.DataFrame.
It is also possible to create a Pool with quantized features before training by calling the quantize method. This is useful if the dataset is used for training multiple times and the quantization parameters do not change. A pre-quantized Pool makes it possible to cache the quantized features data and thus avoid re-running the feature quantization step at the start of each training.
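A minimal sketch of this workflow, assuming quantize can be called with its default quantization parameters and trainPool is an already initialized Pool:
// quantize the features once, before training
val quantizedTrainPool = trainPool.quantize()
// train on the pre-quantized Pool; the quantization step is not repeated
val model = new CatBoostClassifier().fit(quantizedTrainPool)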
Detailed documentation is available at https://catboost.ai/docs/
Type Members
- class CatBoostClassificationModel extends ProbabilisticClassificationModel[Vector, CatBoostClassificationModel] with CatBoostModelTrait[CatBoostClassificationModel]
Classification model trained by CatBoost. Use CatBoostClassifier to train it.
Serialization
Supports standard Spark MLlib serialization. Data can be saved to a distributed filesystem such as HDFS or to local files. When saved to path, two files are created:
- <path>/metadata, which contains Spark-specific metadata in JSON format
- <path>/model, which contains the model in the usual CatBoost format that can be read using other local CatBoost APIs (if stored in a distributed filesystem, it has to be copied to the local filesystem first)
Saving to and loading from local files in standard CatBoost model formats is also supported.
Examples:
Save model:
val trainPool : Pool = ... init Pool ...
val classifier = new CatBoostClassifier
val model = classifier.fit(trainPool)
val path = "/home/user/catboost_spark_models/model0"
model.write.save(path)

Load model:
val dataFrameForPrediction : DataFrame = ... init DataFrame ...
val path = "/home/user/catboost_spark_models/model0"
val model = CatBoostClassificationModel.load(path)
val predictions = model.transform(dataFrameForPrediction)
predictions.show()

Save as a native model:
val trainPool : Pool = ... init Pool ...
val classifier = new CatBoostClassifier
val model = classifier.fit(trainPool)
val path = "/home/user/catboost_native_models/model0.cbm"
model.saveNativeModel(path)

Load native model:
val dataFrameForPrediction : DataFrame = ... init DataFrame ...
val path = "/home/user/catboost_native_models/model0.cbm"
val model = CatBoostClassificationModel.loadNativeModel(path)
val predictions = model.transform(dataFrameForPrediction)
predictions.show()
- class CatBoostClassifier extends ProbabilisticClassifier[Vector, CatBoostClassifier, CatBoostClassificationModel] with CatBoostPredictorTrait[CatBoostClassifier, CatBoostClassificationModel] with ClassifierTrainingParamsTrait
Class to train CatBoostClassificationModel.
The default optimized loss function depends on various conditions:
- Logloss: the label column has only two different values, or the targetBorder parameter is specified.
- MultiClass: the label column has more than two different values and the targetBorder parameter is not specified.
Examples
Binary classification.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("ClassifierTest")
  .getOrCreate()

val srcDataSchema = Seq(
  StructField("features", SQLDataTypes.VectorType),
  StructField("label", StringType)
)

val trainData = Seq(
  Row(Vectors.dense(0.1, 0.2, 0.11), "0"),
  Row(Vectors.dense(0.97, 0.82, 0.33), "1"),
  Row(Vectors.dense(0.13, 0.22, 0.23), "1"),
  Row(Vectors.dense(0.8, 0.62, 0.0), "0")
)
val trainDf = spark.createDataFrame(spark.sparkContext.parallelize(trainData), StructType(srcDataSchema))
val trainPool = new Pool(trainDf)

val evalData = Seq(
  Row(Vectors.dense(0.22, 0.33, 0.9), "1"),
  Row(Vectors.dense(0.11, 0.1, 0.21), "0"),
  Row(Vectors.dense(0.77, 0.0, 0.0), "1")
)
val evalDf = spark.createDataFrame(spark.sparkContext.parallelize(evalData), StructType(srcDataSchema))
val evalPool = new Pool(evalDf)

val classifier = new CatBoostClassifier
val model = classifier.fit(trainPool, Array[Pool](evalPool))
val predictions = model.transform(evalPool.data)
predictions.show()
Multiclass classification.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("ClassifierTest")
  .getOrCreate()

val srcDataSchema = Seq(
  StructField("features", SQLDataTypes.VectorType),
  StructField("label", StringType)
)

val trainData = Seq(
  Row(Vectors.dense(0.1, 0.2, 0.11), "1"),
  Row(Vectors.dense(0.97, 0.82, 0.33), "2"),
  Row(Vectors.dense(0.13, 0.22, 0.23), "1"),
  Row(Vectors.dense(0.8, 0.62, 0.0), "0")
)
val trainDf = spark.createDataFrame(spark.sparkContext.parallelize(trainData), StructType(srcDataSchema))
val trainPool = new Pool(trainDf)

val evalData = Seq(
  Row(Vectors.dense(0.22, 0.33, 0.9), "2"),
  Row(Vectors.dense(0.11, 0.1, 0.21), "0"),
  Row(Vectors.dense(0.77, 0.0, 0.0), "1")
)
val evalDf = spark.createDataFrame(spark.sparkContext.parallelize(evalData), StructType(srcDataSchema))
val evalPool = new Pool(evalDf)

val classifier = new CatBoostClassifier
val model = classifier.fit(trainPool, Array[Pool](evalPool))
val predictions = model.transform(evalPool.data)
predictions.show()
Serialization
Supports standard Spark MLlib serialization. Data can be saved to a distributed filesystem such as HDFS or to local files.
Examples:
Save:
val classifier = new CatBoostClassifier().setIterations(100)
val path = "/home/user/catboost_classifiers/classifier0"
classifier.write.save(path)
Load:
val path = "/home/user/catboost_classifiers/classifier0" val classifier = CatBoostClassifier.load(path) val trainPool : Pool = ... init Pool ... val model = classifier.fit(trainPool)
- trait CatBoostPredictorTrait[Learner <: Predictor[Vector, Learner, Model], Model <: PredictionModel[Vector, Model]] extends Predictor[Vector, Learner, Model] with DatasetParamsTrait with DefaultParamsWritable
Base trait with common functionality for both CatBoostClassifier and CatBoostRegressor
- class CatBoostRegressionModel extends RegressionModel[Vector, CatBoostRegressionModel] with CatBoostModelTrait[CatBoostRegressionModel]
Regression model trained by CatBoost. Use CatBoostRegressor to train it.
Serialization
Supports standard Spark MLlib serialization. Data can be saved to a distributed filesystem such as HDFS or to local files. When saved to path, two files are created:
- <path>/metadata, which contains Spark-specific metadata in JSON format
- <path>/model, which contains the model in the usual CatBoost format that can be read using other local CatBoost APIs (if stored in a distributed filesystem, it has to be copied to the local filesystem first)
Saving to and loading from local files in standard CatBoost model formats is also supported.
Examples:
Save model:
val trainPool : Pool = ... init Pool ...
val regressor = new CatBoostRegressor
val model = regressor.fit(trainPool)
val path = "/home/user/catboost_spark_models/model0"
model.write.save(path)

Load model:
val dataFrameForPrediction : DataFrame = ... init DataFrame ...
val path = "/home/user/catboost_spark_models/model0"
val model = CatBoostRegressionModel.load(path)
val predictions = model.transform(dataFrameForPrediction)
predictions.show()

Save as a native model:
val trainPool : Pool = ... init Pool ...
val regressor = new CatBoostRegressor
val model = regressor.fit(trainPool)
val path = "/home/user/catboost_native_models/model0.cbm"
model.saveNativeModel(path)

Load native model:
val dataFrameForPrediction : DataFrame = ... init DataFrame ...
val path = "/home/user/catboost_native_models/model0.cbm"
val model = CatBoostRegressionModel.loadNativeModel(path)
val predictions = model.transform(dataFrameForPrediction)
predictions.show()
- class CatBoostRegressor extends CatBoostRegressorBase[Vector, CatBoostRegressor, CatBoostRegressionModel] with CatBoostPredictorTrait[CatBoostRegressor, CatBoostRegressionModel] with RegressorTrainingParamsTrait
Class to train CatBoostRegressionModel. The default optimized loss function is RMSE.
Examples
Basic example.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("RegressorTest")
  .getOrCreate()

val srcDataSchema = Seq(
  StructField("features", SQLDataTypes.VectorType),
  StructField("label", StringType)
)

val trainData = Seq(
  Row(Vectors.dense(0.1, 0.2, 0.11), "0.12"),
  Row(Vectors.dense(0.97, 0.82, 0.33), "0.22"),
  Row(Vectors.dense(0.13, 0.22, 0.23), "0.34"),
  Row(Vectors.dense(0.8, 0.62, 0.0), "0.1")
)
val trainDf = spark.createDataFrame(spark.sparkContext.parallelize(trainData), StructType(srcDataSchema))
val trainPool = new Pool(trainDf)

val evalData = Seq(
  Row(Vectors.dense(0.22, 0.33, 0.9), "0.1"),
  Row(Vectors.dense(0.11, 0.1, 0.21), "0.9"),
  Row(Vectors.dense(0.77, 0.0, 0.0), "0.72")
)
val evalDf = spark.createDataFrame(spark.sparkContext.parallelize(evalData), StructType(srcDataSchema))
val evalPool = new Pool(evalDf)

val regressor = new CatBoostRegressor
val model = regressor.fit(trainPool, Array[Pool](evalPool))
val predictions = model.transform(evalPool.data)
predictions.show()
Example with alternative loss function.
...<initialize trainPool, evalPool> val regressor = new CatBoostRegressor().setLossFunction("MAE") val model = regressor.fit(trainPool, Array[Pool](evalPool)) val predictions = model.transform(evalPool.data) predictions.show()
Serialization
Supports standard Spark MLlib serialization. Data can be saved to a distributed filesystem such as HDFS or to local files.
Examples:
Save:
val regressor = new CatBoostRegressor().setLossFunction("MAE")
val path = "/home/user/catboost_regressors/regressor0"
regressor.write.save(path)
Load:
val path = "/home/user/catboost_regressors/regressor0" val regressor = CatBoostRegressor.load(path) val trainPool : Pool = ... init Pool ... val model = regressor.fit(trainPool)
- class CatBoostTrainingContext extends AnyRef
- type EModelType = ru.yandex.catboost.spark.catboost4j_spark.core.src.native_impl.EModelType
- class FeatureImportance extends AnyRef
- class FeatureInteractionScore extends AnyRef
- class Pool extends Params with HasLabelCol with HasFeaturesCol with HasWeightCol with Logging
CatBoost's abstraction of a dataset.
Features data can be stored in raw form (the features column has org.apache.spark.ml.linalg.Vector type) or in quantized form (float feature values are quantized into integer bin values, and the features column has Array[Byte] type).
A raw Pool can be transformed to quantized form using the quantize method. This is useful if the dataset is used for training multiple times and the quantization parameters do not change. A pre-quantized Pool makes it possible to cache the quantized features data and thus avoid re-running the feature quantization step at the start of each training.
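As an illustrative sketch (assuming quantize accepts its default parameters, and that trainPool and evalPool are already initialized raw Pools), a pre-quantized Pool can be reused across several trainings:
// quantize the training and evaluation datasets once ...
val quantizedTrainPool = trainPool.quantize()
val quantizedEvalPool = evalPool.quantize()

// ... then train several models with different settings without repeating the quantization step
val model100 = new CatBoostClassifier().setIterations(100).fit(quantizedTrainPool, Array[Pool](quantizedEvalPool))
val model500 = new CatBoostClassifier().setIterations(500).fit(quantizedTrainPool, Array[Pool](quantizedEvalPool))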
Value Members
- object CatBoostClassificationModel extends MLReadable[CatBoostClassificationModel] with Serializable
- object CatBoostClassifier extends DefaultParamsReadable[CatBoostClassifier] with Serializable
- object CatBoostRegressionModel extends MLReadable[CatBoostRegressionModel] with Serializable
- object CatBoostRegressor extends DefaultParamsReadable[CatBoostRegressor] with Serializable
- object Pool extends Serializable
Companion object for the Pool class, which is CatBoost's abstraction of a dataset.