class CatBoostClassificationModel extends ProbabilisticClassificationModel[Vector, CatBoostClassificationModel] with CatBoostModelTrait[CatBoostClassificationModel]

Classification model trained by CatBoost. Use CatBoostClassifier to train it.

Serialization

Supports standard Spark MLlib serialization. The model can be saved to a distributed filesystem such as HDFS or to local files. Saving to a path creates two files:

  - <path>/metadata, which contains Spark-specific metadata in JSON format
  - <path>/model, which contains the model in the usual CatBoost format. It can be read by other local CatBoost APIs (if stored in a distributed filesystem, it must first be copied to the local filesystem).

Saving to and loading from local files in standard CatBoost model formats is also supported.

Examples:
  1. Save model

    val trainPool : Pool = ... init Pool ...
    val classifier = new CatBoostClassifier
    val model = classifier.fit(trainPool)
    val path = "/home/user/catboost_spark_models/model0"
    model.write.save(path)
  2. Load model

    val dataFrameForPrediction : DataFrame = ... init DataFrame ...
    val path = "/home/user/catboost_spark_models/model0"
    val model = CatBoostClassificationModel.load(path)
    val predictions = model.transform(dataFrameForPrediction)
    predictions.show()
  3. Save as a native model

    val trainPool : Pool = ... init Pool ...
    val classifier = new CatBoostClassifier
    val model = classifier.fit(trainPool)
    val path = "/home/user/catboost_native_models/model0.cbm"
    model.saveNativeModel(path)
  4. Load native model

    val dataFrameForPrediction : DataFrame = ... init DataFrame ...
    val path = "/home/user/catboost_native_models/model0.cbm"
    val model = CatBoostClassificationModel.loadNativeModel(path)
    val predictions = model.transform(dataFrameForPrediction)
    predictions.show()
Linear Supertypes
CatBoostModelTrait[CatBoostClassificationModel], MLWritable, ProbabilisticClassificationModel[Vector, CatBoostClassificationModel], ProbabilisticClassifierParams, HasThresholds, HasProbabilityCol, ClassificationModel[Vector, CatBoostClassificationModel], ClassifierParams, HasRawPredictionCol, PredictionModel[Vector, CatBoostClassificationModel], PredictorParams, HasPredictionCol, HasFeaturesCol, HasLabelCol, Model[CatBoostClassificationModel], Transformer, PipelineStage, Logging, Params, Serializable, Serializable, Identifiable, AnyRef, Any

Instance Constructors

  1. new CatBoostClassificationModel(nativeModel: TFullModel)
  2. new CatBoostClassificationModel(uid: String, nativeModel: TFullModel = null, nativeDimension: Int)

Value Members

  1. final def !=(arg0: Any): Boolean
    Definition Classes
    AnyRef → Any
  2. final def ##(): Int
    Definition Classes
    AnyRef → Any
  3. final def $[T](param: Param[T]): T
    Attributes
    protected
    Definition Classes
    Params
  4. final def ==(arg0: Any): Boolean
    Definition Classes
    AnyRef → Any
  5. final def asInstanceOf[T0]: T0
    Definition Classes
    Any
  6. final def clear(param: Param[_]): CatBoostClassificationModel.this.type
    Definition Classes
    Params
  7. def clone(): AnyRef
    Attributes
    protected[lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( ... ) @native()
  8. def copy(extra: ParamMap): CatBoostClassificationModel
    Definition Classes
    CatBoostClassificationModel → Model → Transformer → PipelineStage → Params
  9. def copyValues[T <: Params](to: T, extra: ParamMap): T
    Attributes
    protected
    Definition Classes
    Params
  10. final def defaultCopy[T <: Params](extra: ParamMap): T
    Attributes
    protected
    Definition Classes
    Params
  11. final def eq(arg0: AnyRef): Boolean
    Definition Classes
    AnyRef
  12. def equals(arg0: Any): Boolean
    Definition Classes
    AnyRef → Any
  13. def explainParam(param: Param[_]): String
    Definition Classes
    Params
  14. def explainParams(): String
    Definition Classes
    Params
  15. def extractInstances(dataset: Dataset[_], numClasses: Int): RDD[Instance]
    Attributes
    protected
    Definition Classes
    ClassifierParams
  16. def extractInstances(dataset: Dataset[_], validateInstance: (Instance) ⇒ Unit): RDD[Instance]
    Attributes
    protected
    Definition Classes
    PredictorParams
  17. def extractInstances(dataset: Dataset[_]): RDD[Instance]
    Attributes
    protected
    Definition Classes
    PredictorParams
  18. final def extractParamMap(): ParamMap
    Definition Classes
    Params
  19. final def extractParamMap(extra: ParamMap): ParamMap
    Definition Classes
    Params
  20. final val featuresCol: Param[String]
    Definition Classes
    HasFeaturesCol
  21. def featuresDataType: DataType
    Attributes
    protected
    Definition Classes
    PredictionModel
  22. def finalize(): Unit
    Attributes
    protected[lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( classOf[java.lang.Throwable] )
  23. final def get[T](param: Param[T]): Option[T]
    Definition Classes
    Params
  24. def getAdditionalColumnsForApply: Seq[StructField]
    Attributes
    protected
    Definition Classes
    CatBoostClassificationModel → CatBoostModelTrait
  25. final def getClass(): Class[_]
    Definition Classes
    AnyRef → Any
    Annotations
    @native()
  26. final def getDefault[T](param: Param[T]): Option[T]
    Definition Classes
    Params
  27. def getFeatureImportance(fstrType: EFstrType = EFstrType.FeatureImportance, data: Pool = null, calcType: ECalcTypeShapValues = ECalcTypeShapValues.Regular): Array[Double]

    fstrType

    Supported values are FeatureImportance, PredictionValuesChange, LossFunctionChange, PredictionDiff

    data

    Required if fstrType is PredictionDiff, in which case it must contain 2 samples. Also required if fstrType is PredictionValuesChange and the model was explicitly trained with the flag to store no leaf weights. Otherwise it can be null.

    calcType

    Used only for PredictionValuesChange. Possible values:

    • Regular: Calculate regular SHAP values
    • Approximate: Calculate approximate SHAP values
    • Exact: Calculate exact SHAP values
    returns

    array of feature importances (index corresponds to the order of features in the model)

    Definition Classes
    CatBoostModelTrait
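    A minimal usage sketch, assuming the model class is importable from the ai.catboost.spark package (an assumption; the package is not stated on this page). The helper name rankedImportances is hypothetical. It calls getFeatureImportance with its default arguments and pairs each value with its feature index:

```scala
import ai.catboost.spark.CatBoostClassificationModel

object FeatureImportanceSketch {
  // Returns (featureIndex, importance) pairs sorted by descending importance.
  // The index is the position of the feature in the model, as documented above.
  def rankedImportances(model: CatBoostClassificationModel): Array[(Int, Double)] = {
    // Default arguments: fstrType = EFstrType.FeatureImportance, data = null
    val importances: Array[Double] = model.getFeatureImportance()
    importances.zipWithIndex
      .map { case (importance, index) => (index, importance) }
      .sortBy { case (_, importance) => -importance }
  }
}
```

    Passing a non-default fstrType such as PredictionDiff would also require the data argument, as described in the parameter notes above.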
  28. def getFeatureImportanceInteraction(): Array[FeatureInteractionScore]

    returns

    array of feature interaction scores

    Definition Classes
    CatBoostModelTrait
  29. def getFeatureImportancePrettified(fstrType: EFstrType = EFstrType.FeatureImportance, data: Pool = null, calcType: ECalcTypeShapValues = ECalcTypeShapValues.Regular): Array[FeatureImportance]

    fstrType

    Supported values are FeatureImportance, PredictionValuesChange, LossFunctionChange, PredictionDiff

    data

    Required if fstrType is PredictionDiff, in which case it must contain 2 samples. Also required if fstrType is PredictionValuesChange and the model was explicitly trained with the flag to store no leaf weights. Otherwise it can be null.

    calcType

    Used only for PredictionValuesChange. Possible values:

    • Regular: Calculate regular SHAP values
    • Approximate: Calculate approximate SHAP values
    • Exact: Calculate exact SHAP values
    returns

    array of feature importances sorted in descending order by importance

    Definition Classes
    CatBoostModelTrait
  30. def getFeatureImportanceShapInteractionValues(data: Pool, featureIndices: Pair[Int, Int] = null, featureNames: Pair[String, String] = null, preCalcMode: EPreCalcShapValues = EPreCalcShapValues.Auto, calcType: ECalcTypeShapValues = ECalcTypeShapValues.Regular, outputColumns: Array[String] = null): DataFrame

    SHAP interaction values are calculated for all feature pairs if neither featureIndices nor featureNames is specified.

    data

    dataset to calculate SHAP interaction values

    featureIndices

    (optional) pair of feature indices to calculate SHAP interaction values for.

    featureNames

    (optional) pair of feature names to calculate SHAP interaction values for.

    preCalcMode

    Possible values:

    • Auto: Choose the better of the two strategies below. Direct SHAP values calculation is used only if the data size is smaller than the average number of leaves.
    • UsePreCalc: Calculate SHAP values for every leaf in preprocessing. Final complexity is O(NT(D+F)) + O(TL^2 D^2), where N is the number of documents (objects), T is the number of trees, D is the average tree depth, F is the average number of features in a tree, and L is the average number of leaves in a tree. This is much faster (because of a smaller constant) than direct calculation when N >> L.
    • NoPreCalc: Use direct SHAP values calculation with complexity O(NTLD^2). The direct algorithm is faster when N < L (algorithm from https://arxiv.org/abs/1802.03888).
    calcType

    Possible values:

    • Regular: Calculate regular SHAP values
    • Approximate: Calculate approximate SHAP values
    • Exact: Calculate exact SHAP values
    outputColumns

    columns from data to add to output DataFrame, if null - add all columns

    returns

    • for binary classification or regression: DataFrame which contains outputColumns and "featureIdx1", "featureIdx2", "shapInteractionValue" columns
    • for multiclass: DataFrame which contains outputColumns and "classIdx", "featureIdx1", "featureIdx2", "shapInteractionValue" columns
    Definition Classes
    CatBoostModelTrait
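    A hedged sketch of post-processing the result, assuming a binary classification model and that the classes live in the ai.catboost.spark package (an assumption). The helper name strongestInteractions is hypothetical. It computes interaction values for all feature pairs with the default arguments and keeps the rows with the largest absolute values, using the "shapInteractionValue" column named above:

```scala
import ai.catboost.spark.{CatBoostClassificationModel, Pool}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{abs, col}

object ShapInteractionSketch {
  // Returns the n rows with the largest absolute SHAP interaction value.
  def strongestInteractions(
      model: CatBoostClassificationModel,
      data: Pool,
      n: Int
  ): DataFrame = {
    // Neither featureIndices nor featureNames is given, so all pairs are computed.
    val interactions: DataFrame = model.getFeatureImportanceShapInteractionValues(data)
    interactions
      .orderBy(abs(col("shapInteractionValue")).desc)
      .limit(n)
  }
}
```

    Because all feature pairs are computed for every row of data, this can be expensive on wide datasets; restricting the call to one pair through featureIndices or featureNames reduces the work.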
  31. def getFeatureImportanceShapValues(data: Pool, preCalcMode: EPreCalcShapValues = EPreCalcShapValues.Auto, calcType: ECalcTypeShapValues = ECalcTypeShapValues.Regular, modelOutputType: EExplainableModelOutput = EExplainableModelOutput.Raw, referenceData: Pool = null, outputColumns: Array[String] = null): DataFrame

    data

    dataset to calculate SHAP values for

    preCalcMode

    Possible values:

    • Auto: Choose the better of the two strategies below. Direct SHAP values calculation is used only if the data size is smaller than the average number of leaves.
    • UsePreCalc: Calculate SHAP values for every leaf in preprocessing. Final complexity is O(NT(D+F)) + O(TL^2 D^2), where N is the number of documents (objects), T is the number of trees, D is the average tree depth, F is the average number of features in a tree, and L is the average number of leaves in a tree. This is much faster (because of a smaller constant) than direct calculation when N >> L.
    • NoPreCalc: Use direct SHAP values calculation with complexity O(NTLD^2). The direct algorithm is faster when N < L (algorithm from https://arxiv.org/abs/1802.03888).
    calcType

    Possible values:

    • Regular: Calculate regular SHAP values
    • Approximate: Calculate approximate SHAP values
    • Exact: Calculate exact SHAP values
    referenceData

    Reference data for Independent Tree SHAP values (see https://arxiv.org/abs/1905.04610v1). If referenceData is not null, Independent Tree SHAP values are calculated.

    outputColumns

    columns from data to add to output DataFrame, if null - add all columns

    returns

    • for regression and binary classification models: DataFrame which contains outputColumns and "shapValues" column with Vector of length (n_features + 1) with SHAP values
    • for multiclass models: DataFrame which contains outputColumns and "shapValues" column with Matrix of shape (n_classes x (n_features + 1)) with SHAP values
    Definition Classes
    CatBoostModelTrait
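    A minimal sketch of calling this method with its default arguments, assuming the classes are importable from ai.catboost.spark (an assumption). The helper name shapValuesOnly is hypothetical. It keeps only the "shapValues" column described above:

```scala
import ai.catboost.spark.{CatBoostClassificationModel, Pool}
import org.apache.spark.sql.DataFrame

object ShapValuesSketch {
  // Computes SHAP values for every row in `data` and keeps only the
  // "shapValues" column (a Vector for binary models, a Matrix for multiclass).
  def shapValuesOnly(model: CatBoostClassificationModel, data: Pool): DataFrame = {
    // Defaults: preCalcMode = Auto, calcType = Regular, referenceData = null
    val withShap: DataFrame = model.getFeatureImportanceShapValues(data)
    withShap.select("shapValues")
  }
}
```

    Calling shapValuesOnly(model, pool).show(5, truncate = false) would print the first few SHAP vectors for inspection.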
  32. final def getFeaturesCol: String
    Definition Classes
    HasFeaturesCol
  33. final def getLabelCol: String
    Definition Classes
    HasLabelCol
  34. final def getOrDefault[T](param: Param[T]): T
    Definition Classes
    Params
  35. def getParam(paramName: String): Param[Any]
    Definition Classes
    Params
  36. final def getPredictionCol: String
    Definition Classes
    HasPredictionCol
  37. final def getProbabilityCol: String
    Definition Classes
    HasProbabilityCol
  38. final def getRawPredictionCol: String
    Definition Classes
    HasRawPredictionCol
  39. def getResultIteratorForApply(objectsDataProvider: SWIGTYPE_p_NCB__TObjectsDataProviderPtr, dstRows: ArrayBuffer[Array[Any]], localExecutor: TLocalExecutor): Iterator[Row]
    Attributes
    protected
    Definition Classes
    CatBoostClassificationModel → CatBoostModelTrait
  40. def getThresholds: Array[Double]
    Definition Classes
    HasThresholds
  41. final def hasDefault[T](param: Param[T]): Boolean
    Definition Classes
    Params
  42. def hasParam(paramName: String): Boolean
    Definition Classes
    Params
  43. def hasParent: Boolean
    Definition Classes
    Model
  44. def hashCode(): Int
    Definition Classes
    AnyRef → Any
    Annotations
    @native()
  45. def initializeLogIfNecessary(isInterpreter: Boolean, silent: Boolean): Boolean
    Attributes
    protected
    Definition Classes
    Logging
  46. def initializeLogIfNecessary(isInterpreter: Boolean): Unit
    Attributes
    protected
    Definition Classes
    Logging
  47. final def isDefined(param: Param[_]): Boolean
    Definition Classes
    Params
  48. final def isInstanceOf[T0]: Boolean
    Definition Classes
    Any
  49. final def isSet(param: Param[_]): Boolean
    Definition Classes
    Params
  50. def isTraceEnabled(): Boolean
    Attributes
    protected
    Definition Classes
    Logging
  51. final val labelCol: Param[String]
    Definition Classes
    HasLabelCol
  52. def log: Logger
    Attributes
    protected
    Definition Classes
    Logging
  53. def logDebug(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  54. def logDebug(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  55. def logError(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  56. def logError(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  57. def logInfo(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  58. def logInfo(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  59. def logName: String
    Attributes
    protected
    Definition Classes
    Logging
  60. def logTrace(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  61. def logTrace(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  62. def logWarning(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  63. def logWarning(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  64. var nativeDimension: Int
    Attributes
    protected
    Definition Classes
    CatBoostClassificationModel → CatBoostModelTrait
  65. final def ne(arg0: AnyRef): Boolean
    Definition Classes
    AnyRef
  66. final def notify(): Unit
    Definition Classes
    AnyRef
    Annotations
    @native()
  67. final def notifyAll(): Unit
    Definition Classes
    AnyRef
    Annotations
    @native()
  68. def numClasses: Int
    Definition Classes
    CatBoostClassificationModel → ClassificationModel
  69. def numFeatures: Int
    Definition Classes
    PredictionModel
    Annotations
    @Since( "1.6.0" )
  70. lazy val params: Array[Param[_]]
    Definition Classes
    Params
  71. var parent: Estimator[CatBoostClassificationModel]
    Definition Classes
    Model
  72. def predict(features: Vector): Double
    Definition Classes
    ClassificationModel → PredictionModel
  73. def predictProbability(features: Vector): Vector
    Definition Classes
    ProbabilisticClassificationModel
    Annotations
    @Since( "3.0.0" )
  74. def predictRaw(features: Vector): Vector

    Prefer batch computations operating on datasets as a whole for efficiency

    Definition Classes
    CatBoostClassificationModel → ClassificationModel
  75. final def predictRawImpl(features: Vector): Array[Double]

    Prefer batch computations operating on datasets as a whole for efficiency

    Definition Classes
    CatBoostModelTrait
  76. final val predictionCol: Param[String]
    Definition Classes
    HasPredictionCol
  77. def probability2prediction(probability: Vector): Double
    Attributes
    protected
    Definition Classes
    ProbabilisticClassificationModel
  78. final val probabilityCol: Param[String]
    Definition Classes
    HasProbabilityCol
  79. def raw2prediction(rawPrediction: Vector): Double
    Attributes
    protected
    Definition Classes
    ProbabilisticClassificationModel → ClassificationModel
  80. def raw2probability(rawPrediction: Vector): Vector
    Attributes
    protected
    Definition Classes
    ProbabilisticClassificationModel
  81. def raw2probabilityInPlace(rawPrediction: Vector): Vector

    Prefer batch computations operating on datasets as a whole for efficiency

    Attributes
    protected
    Definition Classes
    CatBoostClassificationModel → ProbabilisticClassificationModel
  82. final val rawPredictionCol: Param[String]
    Definition Classes
    HasRawPredictionCol
  83. def save(path: String): Unit
    Definition Classes
    MLWritable
    Annotations
    @Since( "1.6.0" ) @throws( ... )
  84. def saveNativeModel(fileName: String, format: EModelType = EModelType.CatboostBinary, exportParameters: Map[String, Any] = null, pool: Pool = null): Unit

    Save the model to a local file.

    fileName

    The path to the output model.

    format

    The output format of the model. Possible values:

    • CatboostBinary: CatBoost binary format (default).
    • AppleCoreML: Apple CoreML format (only datasets without categorical features are currently supported).
    • Cpp: Standalone C++ code (multiclassification models are not currently supported). See the C++ section for details on applying the resulting model.
    • Python: Standalone Python code (multiclassification models are not currently supported). See the Python section for details on applying the resulting model.
    • Json: JSON format. Refer to the CatBoost JSON model tutorial for format details.
    • Onnx: ONNX-ML format (only datasets without categorical features are currently supported). Refer to https://onnx.ai for details.
    • Pmml: PMML version 4.3 format. If the training dataset contains categorical features, they must be interpreted as one-hot encoded during training. This can be accomplished by setting the --one-hot-max-size/one_hot_max_size parameter to a value greater than the maximum number of unique categorical feature values among all categorical features in the dataset. Note: multiclassification models are not currently supported. See the PMML section for details on applying the resulting model.

    exportParameters

    Additional format-dependent parameters for the AppleCoreML, Onnx, or Pmml formats. See the Python API documentation for details.

    pool

    The dataset previously used for training. This parameter is required if the model contains categorical features and the output format is Cpp, Python, or Json.

    Definition Classes
    CatBoostModelTrait
    Example:

      val spark = SparkSession.builder()
        .master("local[*]")
        .appName("testSaveLocalModel")
        .getOrCreate()

      val pool = Pool.load(
        spark,
        "dsv:///home/user/datasets/my_dataset/train.dsv",
        columnDescription = "/home/user/datasets/my_dataset/cd"
      )

      val regressor = new CatBoostRegressor()
      val model = regressor.fit(pool)

      // save in CatboostBinary format (the default)
      model.saveNativeModel("/home/user/model/model.cbm")

      // save in ONNX format with metadata
      model.saveNativeModel(
        "/home/user/model/model.onnx",
        EModelType.Onnx,
        Map(
          "onnx_domain" -> "ai.catboost",
          "onnx_model_version" -> 1,
          "onnx_doc_string" -> "test model for regression",
          "onnx_graph_name" -> "CatBoostModel_for_regression"
        )
      )
  85. final def set(paramPair: ParamPair[_]): CatBoostClassificationModel.this.type
    Attributes
    protected
    Definition Classes
    Params
  86. final def set(param: String, value: Any): CatBoostClassificationModel.this.type
    Attributes
    protected
    Definition Classes
    Params
  87. final def set[T](param: Param[T], value: T): CatBoostClassificationModel.this.type
    Definition Classes
    Params
  88. final def setDefault(paramPairs: ParamPair[_]*): CatBoostClassificationModel.this.type
    Attributes
    protected
    Definition Classes
    Params
  89. final def setDefault[T](param: Param[T], value: T): CatBoostClassificationModel.this.type
    Attributes
    protected
    Definition Classes
    Params
  90. def setFeaturesCol(value: String): CatBoostClassificationModel
    Definition Classes
    PredictionModel
  91. def setParent(parent: Estimator[CatBoostClassificationModel]): CatBoostClassificationModel
    Definition Classes
    Model
  92. def setPredictionCol(value: String): CatBoostClassificationModel
    Definition Classes
    PredictionModel
  93. def setProbabilityCol(value: String): CatBoostClassificationModel
    Definition Classes
    ProbabilisticClassificationModel
  94. def setRawPredictionCol(value: String): CatBoostClassificationModel
    Definition Classes
    ClassificationModel
  95. def setThresholds(value: Array[Double]): CatBoostClassificationModel
    Definition Classes
    ProbabilisticClassificationModel
  96. final def synchronized[T0](arg0: ⇒ T0): T0
    Definition Classes
    AnyRef
  97. val thresholds: DoubleArrayParam
    Definition Classes
    HasThresholds
  98. def toString(): String
    Definition Classes
    Identifiable → AnyRef → Any
  99. def transform(dataset: Dataset[_]): DataFrame
    Definition Classes
    CatBoostClassificationModel → ProbabilisticClassificationModel → ClassificationModel → PredictionModel → Transformer
  100. def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame
    Definition Classes
    Transformer
    Annotations
    @Since( "2.0.0" )
  101. def transform(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): DataFrame
    Definition Classes
    Transformer
    Annotations
    @Since( "2.0.0" ) @varargs()
  102. def transformCatBoostImpl(dataset: Dataset[_]): DataFrame
    Attributes
    protected
    Definition Classes
    CatBoostModelTrait
  103. final def transformImpl(dataset: Dataset[_]): DataFrame
    Definition Classes
    ClassificationModel → PredictionModel
  104. def transformPool(dataset: Pool): DataFrame

    This function is useful when the dataset has already been quantized, but it works with any Pool.

    Definition Classes
    CatBoostModelTrait
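    A short sketch of applying the model directly to a Pool, assuming the classes are importable from ai.catboost.spark and that the output contains the prediction and probability columns configured on the model (both are assumptions). The helper name predictOnPool is hypothetical:

```scala
import ai.catboost.spark.{CatBoostClassificationModel, Pool}
import org.apache.spark.sql.DataFrame

object TransformPoolSketch {
  // Applies the model to a Pool (quantized or not) and keeps the
  // predicted label and class probabilities.
  def predictOnPool(model: CatBoostClassificationModel, pool: Pool): DataFrame = {
    val predictions: DataFrame = model.transformPool(pool)
    // Column names are read from the model's own params instead of being hard-coded.
    predictions.select(model.getPredictionCol, model.getProbabilityCol)
  }
}
```

    Reading the column names through getPredictionCol and getProbabilityCol keeps the sketch valid if the columns were renamed with setPredictionCol or setProbabilityCol.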
  105. def transformSchema(schema: StructType): StructType
    Definition Classes
    ProbabilisticClassificationModel → ClassificationModel → PredictionModel → PipelineStage
  106. def transformSchema(schema: StructType, logging: Boolean): StructType
    Attributes
    protected
    Definition Classes
    PipelineStage
    Annotations
    @DeveloperApi()
  107. val uid: String
    Definition Classes
    CatBoostClassificationModel → Identifiable
  108. def validateAndTransformSchema(schema: StructType, fitting: Boolean, featuresDataType: DataType): StructType
    Attributes
    protected
    Definition Classes
    ProbabilisticClassifierParams → ClassifierParams → PredictorParams
  109. final def wait(): Unit
    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  110. final def wait(arg0: Long, arg1: Int): Unit
    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  111. final def wait(arg0: Long): Unit
    Definition Classes
    AnyRef
    Annotations
    @throws( ... ) @native()
  112. def write: MLWriter
    Definition Classes
    CatBoostModelTrait → MLWritable
