Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

SPARK-28985 #262

Merged
merged 3 commits into from
Oct 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion third_party/3/pyspark/ml/_typing.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import Any, Dict, TypeVar, Union

import pyspark.ml.param
import pyspark.ml.base
import pyspark.ml.param
import pyspark.ml.util
import pyspark.ml.wrapper

# Shared typing aliases and type variables for the pyspark.ml stubs.

# Mapping from a Param object to an arbitrary value.
ParamMap = Dict[pyspark.ml.param.Param, Any]
# Either an Estimator or a Transformer.
PipelineStage = Union[pyspark.ml.base.Estimator, pyspark.ml.base.Transformer]

# Unconstrained type variable.
T = TypeVar("T")
# Type variable restricted to Params and its subclasses.
P = TypeVar("P", bound=pyspark.ml.param.Params)
# Type variable restricted to Transformer and its subclasses.
M = TypeVar("M", bound=pyspark.ml.base.Transformer)
# Type variable restricted to JavaTransformer and its subclasses.
JM = TypeVar("JM", bound=pyspark.ml.wrapper.JavaTransformer)
55 changes: 36 additions & 19 deletions third_party/3/pyspark/ml/classification.pyi
Original file line number Diff line number Diff line change
@@ -1,32 +1,48 @@
# Stubs for pyspark.ml.classification (Python 3)

import abc
from typing import Any, Dict, List, Optional, TypeVar
from pyspark.ml._typing import M, P, ParamMap
from pyspark.ml._typing import JM, M, P, T, ParamMap

from pyspark.ml.base import Estimator, Model, Transformer
from pyspark.ml.linalg import Matrix, Vector
from pyspark.ml.param.shared import *
from pyspark.ml.regression import DecisionTreeModel, DecisionTreeParams, DecisionTreeRegressionModel, GBTParams, HasVarianceImpurity, RandomForestParams, TreeEnsembleModel
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.wrapper import JavaWrapper
from pyspark.ml.wrapper import JavaPredictionModel, JavaPredictor, JavaPredictorParams, JavaWrapper, JavaTransformer
from pyspark.sql.dataframe import DataFrame

class JavaClassificationModel(JavaPredictionModel):
# Parameter mixin stub combining HasRawPredictionCol with JavaPredictorParams;
# it declares no members of its own.
class JavaClassifierParams(
    HasRawPredictionCol,
    JavaPredictorParams,
):
    ...

# Stub for the classifier base class: a JavaPredictor parameterised by JM that
# also carries JavaClassifierParams. Declared with the ABCMeta metaclass.
class JavaClassifier(
    JavaPredictor[JM],
    JavaClassifierParams,
    metaclass=abc.ABCMeta,
):
    # `self: P -> P` types the setter as returning the caller's own class.
    def setRawPredictionCol(self: P, value: str) -> P: ...

# Stub for the classification model base class: a JavaPredictionModel
# parameterised by T that also carries JavaClassifierParams.
class JavaClassificationModel(
    JavaPredictionModel[T],
    JavaClassifierParams,
):
    # `self: P -> P` types the setter as returning the caller's own class.
    def setRawPredictionCol(self: P, value: str) -> P: ...

    # Read-only property typed as int.
    @property
    def numClasses(self) -> int: ...

class LinearSVC(JavaEstimator[LinearSVCModel], HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasRawPredictionCol, HasFitIntercept, HasStandardization, HasThreshold, HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable[LinearSVC]):
# Parameter mixin stub combining HasProbabilityCol and HasThresholds with
# JavaClassifierParams; it declares no members of its own.
class JavaProbabilisticClassifierParams(
    HasProbabilityCol,
    HasThresholds,
    JavaClassifierParams,
):
    ...

# Stub for the probabilistic classifier base class: a JavaClassifier
# parameterised by JM that also carries JavaProbabilisticClassifierParams.
# Declared with the ABCMeta metaclass.
class JavaProbabilisticClassifier(
    JavaClassifier[JM],
    JavaProbabilisticClassifierParams,
    metaclass=abc.ABCMeta,
):
    # Both setters use `self: P -> P`, so they are typed as returning the
    # caller's own class.
    def setProbabilityCol(self: P, value: str) -> P: ...
    def setThresholds(self: P, value: List[float]) -> P: ...

# Stub for the probabilistic classification model base class: a
# JavaClassificationModel parameterised by T that also carries
# JavaProbabilisticClassifierParams.
class JavaProbabilisticClassificationModel(
    JavaClassificationModel[T],
    JavaProbabilisticClassifierParams,
):
    # Both setters annotate `self: P` so the returned P is bound to the
    # caller's own class. Without the annotation on `self`, a return type of
    # P is an unbound type variable and type checkers cannot infer it.
    def setProbabilityCol(self: P, value: str) -> P: ...
    def setThresholds(self: P, value: List[float]) -> P: ...

# Stub for the LinearSVC estimator: a JavaClassifier producing LinearSVCModel,
# mixed with the shared-param classes listed below and ML read/write support.
class LinearSVC(
    JavaClassifier[LinearSVCModel],
    HasMaxIter,
    HasRegParam,
    HasTol,
    HasFitIntercept,
    HasStandardization,
    HasWeightCol,
    HasAggregationDepth,
    HasThreshold,
    JavaMLWritable,
    JavaMLReadable[LinearSVC],
):
    # All constructor arguments are keyword-only and optional.
    def __init__(
        self,
        *,
        featuresCol: str = ...,
        labelCol: str = ...,
        predictionCol: str = ...,
        maxIter: int = ...,
        regParam: float = ...,
        tol: float = ...,
        rawPredictionCol: str = ...,
        fitIntercept: bool = ...,
        standardization: bool = ...,
        threshold: float = ...,
        weightCol: Optional[str] = ...,
        aggregationDepth: int = ...
    ) -> None: ...

    # Accepts the same keyword-only arguments as __init__; typed as
    # returning LinearSVC.
    def setParams(
        self,
        *,
        featuresCol: str = ...,
        labelCol: str = ...,
        predictionCol: str = ...,
        maxIter: int = ...,
        regParam: float = ...,
        tol: float = ...,
        rawPredictionCol: str = ...,
        fitIntercept: bool = ...,
        standardization: bool = ...,
        threshold: float = ...,
        weightCol: Optional[str] = ...,
        aggregationDepth: int = ...
    ) -> LinearSVC: ...

class LinearSVCModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable[LinearSVCModel]):
# Stub for the LinearSVC model: a JavaClassificationModel over Vector with ML
# read/write support.
class LinearSVCModel(
    JavaClassificationModel[Vector],
    JavaMLWritable,
    JavaMLReadable[LinearSVCModel],
):
    # Read-only property typed as Vector.
    @property
    def coefficients(self) -> Vector: ...

    # Read-only property typed as float.
    @property
    def intercept(self) -> float: ...

class LogisticRegression(JavaEstimator[LogisticRegressionModel], HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds, HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable[LogisticRegression]):
class LogisticRegression(JavaProbabilisticClassifier[LogisticRegressionModel], HasMaxIter, HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds, HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable[LogisticRegression]):
threshold: Param[float]
family: Param[str]
lowerBoundsOnCoefficients: Param[Matrix]
Expand All @@ -50,7 +66,7 @@ class LogisticRegression(JavaEstimator[LogisticRegressionModel], HasFeaturesCol,
def setUpperBoundsOnIntercepts(self, value: Vector) -> LogisticRegression: ...
def getUpperBoundsOnIntercepts(self) -> Vector: ...

class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable[LogisticRegressionModel], HasTrainingSummary):
class LogisticRegressionModel(JavaProbabilisticClassificationModel[Vector], JavaMLWritable, JavaMLReadable[LogisticRegressionModel], HasTrainingSummary[LogisticRegressionTrainingSummary]):
@property
def coefficients(self) -> Vector: ...
@property
Expand Down Expand Up @@ -127,7 +143,7 @@ class TreeClassifierParams:
def __init__(self) -> None: ...
def getImpurity(self) -> str: ...

class DecisionTreeClassifier(JavaEstimator[DecisionTreeClassificationModel], HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams, TreeClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable, JavaMLReadable[DecisionTreeClassifier]):
class DecisionTreeClassifier(JavaProbabilisticClassifier[DecisionTreeClassificationModel], HasWeightCol, DecisionTreeParams, TreeClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable, JavaMLReadable[DecisionTreeClassifier]):
def __init__(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., probabilityCol: str = ..., rawPredictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., impurity: str = ..., seed: Optional[int] = ..., weightCol: Optional[str] = ..., leafCol: str = ...) -> None: ...
def setParams(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., probabilityCol: str = ..., rawPredictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., impurity: str = ..., seed: Optional[int] = ..., weightCol: Optional[str] = ..., leafCol: str = ...) -> DecisionTreeClassifier: ...
def setMaxDepth(self, value: int) -> DecisionTreeClassifier: ...
Expand All @@ -138,11 +154,11 @@ class DecisionTreeClassifier(JavaEstimator[DecisionTreeClassificationModel], Has
def setCacheNodeIds(self, value: bool) -> DecisionTreeClassifier: ...
def setImpurity(self, value: str) -> DecisionTreeClassifier: ...

class DecisionTreeClassificationModel(DecisionTreeModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable[DecisionTreeClassificationModel]):
# Stub for the decision-tree classification model: combines DecisionTreeModel
# with JavaProbabilisticClassificationModel over Vector, plus ML read/write
# support.
class DecisionTreeClassificationModel(
    DecisionTreeModel,
    JavaProbabilisticClassificationModel[Vector],
    JavaMLWritable,
    JavaMLReadable[DecisionTreeClassificationModel],
):
    # Read-only property typed as Vector.
    @property
    def featureImportances(self) -> Vector: ...

class RandomForestClassifier(JavaEstimator[RandomForestClassificationModel], HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed, HasRawPredictionCol, HasProbabilityCol, RandomForestParams, TreeClassifierParams, HasCheckpointInterval, JavaMLWritable, JavaMLReadable[RandomForestClassifier]):
class RandomForestClassifier(JavaProbabilisticClassifier[RandomForestClassificationModel], HasSeed, RandomForestParams, TreeClassifierParams, HasCheckpointInterval, JavaMLWritable, JavaMLReadable[RandomForestClassifier]):
def __init__(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., probabilityCol: str = ..., rawPredictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., impurity: str = ..., numTrees: int = ..., featureSubsetStrategy: str = ..., seed: Optional[int] = ..., subsamplingRate: float = ..., leafCol: str = ...) -> None: ...
def setParams(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., probabilityCol: str = ..., rawPredictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., seed: Optional[int] = ..., impurity: str = ..., numTrees: int = ..., featureSubsetStrategy: str = ..., subsamplingRate: float = ..., leafCol: str = ...) -> RandomForestClassifier: ...
def setMaxDepth(self, value: int) -> RandomForestClassifier: ...
Expand All @@ -156,7 +172,7 @@ class RandomForestClassifier(JavaEstimator[RandomForestClassificationModel], Has
def setSubsamplingRate(self, value: float) -> RandomForestClassifier: ...
def setFeatureSubsetStrategy(self, value: str) -> RandomForestClassifier: ...

class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable[RandomForestClassificationModel]):
class RandomForestClassificationModel(TreeEnsembleModel, JavaProbabilisticClassificationModel[Vector], JavaMLWritable, JavaMLReadable[RandomForestClassificationModel]):
@property
def featureImportances(self) -> Vector: ...
@property
Expand All @@ -167,7 +183,7 @@ class GBTClassifierParams(GBTParams, HasVarianceImpurity):
lossType: Param[str]
def getLossType(self) -> str: ...

class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, GBTClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable, JavaMLReadable[GBTClassifier]):
class GBTClassifier(JavaProbabilisticClassifier[GBTClassificationModel], GBTClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable, JavaMLReadable[GBTClassifier]):
def __init__(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., lossType: str = ..., maxIter: int = ..., stepSize: float = ..., seed: Optional[int] = ..., subsamplingRate: float = ..., featureSubsetStrategy: str = ..., validationTol: float = ..., validationIndicatorCol: Optional[str] = ..., leafCol: str = ...) -> None: ...
def setParams(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., maxDepth: int = ..., maxBins: int = ..., minInstancesPerNode: int = ..., minInfoGain: float = ..., maxMemoryInMB: int = ..., cacheNodeIds: bool = ..., checkpointInterval: int = ..., lossType: str = ..., maxIter: int = ..., stepSize: float = ..., seed: Optional[int] = ..., subsamplingRate: float = ..., featureSubsetStrategy: str = ..., validationTol: float = ..., validationIndicatorCol: Optional[str] = ..., leafCol: str = ...) -> GBTClassifier: ...
def setMaxDepth(self, value: int) -> GBTClassifier: ...
Expand All @@ -182,14 +198,14 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
def setFeatureSubsetStrategy(self, value: str) -> GBTClassifier: ...
def setValidationIndicatorCol(self, value: str) -> GBTClassifier: ...

class GBTClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable[GBTClassificationModel]):
# Stub for the GBT classification model: combines TreeEnsembleModel with
# JavaProbabilisticClassificationModel over Vector, plus ML read/write support.
class GBTClassificationModel(
    TreeEnsembleModel,
    JavaProbabilisticClassificationModel[Vector],
    JavaMLWritable,
    JavaMLReadable[GBTClassificationModel],
):
    # Read-only property typed as Vector.
    @property
    def featureImportances(self) -> Vector: ...

    # Read-only property typed as a list of DecisionTreeRegressionModel.
    @property
    def trees(self) -> List[DecisionTreeRegressionModel]: ...

    # Takes a DataFrame and is typed as returning a list of floats.
    def evaluateEachIteration(self, dataset: DataFrame) -> List[float]: ...

class NaiveBayes(JavaEstimator[NaiveBayesModel], HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, HasThresholds, HasWeightCol, JavaMLWritable, JavaMLReadable[NaiveBayes]):
class NaiveBayes(JavaProbabilisticClassifier[NaiveBayesModel], HasThresholds, HasWeightCol, JavaMLWritable, JavaMLReadable[NaiveBayes]):
smoothing: Param[float]
modelType: Param[str]
def __init__(self, *, featuresCol: str = ..., labelCol: str = ..., predictionCol: str = ..., probabilityCol: str = ..., rawPredictionCol: str = ..., smoothing: float = ..., modelType: str = ..., thresholds: Optional[List[float]] = ..., weightCol: Optional[str] = ...) -> None: ...
Expand All @@ -199,13 +215,13 @@ class NaiveBayes(JavaEstimator[NaiveBayesModel], HasFeaturesCol, HasLabelCol, Ha
def setModelType(self, value: str) -> NaiveBayes: ...
def getModelType(self) -> str: ...

class NaiveBayesModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable[NaiveBayesModel]):
# Stub for the NaiveBayes model: a JavaProbabilisticClassificationModel over
# Vector with ML read/write support.
class NaiveBayesModel(
    JavaProbabilisticClassificationModel[Vector],
    JavaMLWritable,
    JavaMLReadable[NaiveBayesModel],
):
    # Read-only property typed as Vector.
    @property
    def pi(self) -> Vector: ...

    # Read-only property typed as Matrix.
    @property
    def theta(self) -> Matrix: ...

class MultilayerPerceptronClassifier(JavaEstimator[MultilayerPerceptronClassificationModel], HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, HasStepSize, HasSolver, JavaMLWritable, JavaMLReadable[MultilayerPerceptronClassifier], HasProbabilityCol, HasRawPredictionCol):
class MultilayerPerceptronClassifier(JavaProbabilisticClassifier[MultilayerPerceptronClassificationModel], HasMaxIter, HasTol, HasSeed, HasStepSize, HasSolver, JavaMLWritable, JavaMLReadable[MultilayerPerceptronClassifier]):
layers: Param[List[int]]
blockSize: Param[int]
solver: Param[str]
Expand All @@ -221,13 +237,13 @@ class MultilayerPerceptronClassifier(JavaEstimator[MultilayerPerceptronClassific
def setInitialWeights(self, value: Vector) -> MultilayerPerceptronClassifier: ...
def getInitialWeights(self) -> Vector: ...

class MultilayerPerceptronClassificationModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable[MultilayerPerceptronClassificationModel]):
# Stub for the multilayer-perceptron classification model: a
# JavaProbabilisticClassificationModel over Vector with ML read/write support.
class MultilayerPerceptronClassificationModel(
    JavaProbabilisticClassificationModel[Vector],
    JavaMLWritable,
    JavaMLReadable[MultilayerPerceptronClassificationModel],
):
    # Read-only property typed as a list of ints.
    @property
    def layers(self) -> List[int]: ...

    # Read-only property typed as Vector.
    @property
    def weights(self) -> Vector: ...

class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol, HasRawPredictionCol):
# Parameter mixin stub: adds a `classifier` param on top of
# JavaClassifierParams and HasWeightCol.
class OneVsRestParams(
    JavaClassifierParams,
    HasWeightCol,
):
    # Param holding an Estimator.
    classifier: Param[Estimator]

    # NOTE(review): M appears only in the return type and is not bound by any
    # argument — confirm this is the intended annotation.
    def getClassifier(self) -> Estimator[M]: ...

Expand All @@ -240,4 +256,5 @@ class OneVsRest(Estimator[OneVsRestModel], OneVsRestParams, HasParallelism, Java
# Stub for the OneVsRest model: a Model carrying OneVsRestParams with ML
# read/write support. Constructed from a list of Transformer objects.
class OneVsRestModel(
    Model,
    OneVsRestParams,
    JavaMLReadable[OneVsRestModel],
    JavaMLWritable,
):
    # The list of Transformer objects held by the model.
    models: List[Transformer]

    def __init__(self, models: List[Transformer]) -> None: ...

    # NOTE(review): declared on OneVsRestModel but typed as returning
    # OneVsRest — confirm this setter belongs on the model class.
    def setClassifier(self, value: Estimator[M]) -> OneVsRest: ...

    # Accepts an optional ParamMap; typed as returning OneVsRestModel.
    def copy(self, extra: Optional[ParamMap] = ...) -> OneVsRestModel: ...
Loading