diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 3ffed39898d84..c4f41d07800c5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -48,7 +48,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam /** * Param for the name of family which is a description of the error distribution * to be used in the model. - * Supported options: "gaussian", "binomial", "poisson" and "gamma". + * Supported options: "gaussian", "binomial", "poisson", "gamma" and "tweedie". * Default is "gaussian". * * @group param @@ -63,10 +63,35 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam @Since("2.0.0") def getFamily: String = $(family) + /** + * Param for the power in the variance function of the Tweedie distribution which provides + * the relationship between the variance and mean of the distribution. + * Only applicable for the Tweedie family. + * (see + * Tweedie Distribution (Wikipedia)) + * Supported values: 0 and [1, Inf). + * Note that variance power 0, 1, or 2 corresponds to the Gaussian, Poisson or Gamma + * family, respectively. + * + * @group param + */ + @Since("2.2.0") + final val variancePower: DoubleParam = new DoubleParam(this, "variancePower", + "The power in the variance function of the Tweedie distribution which characterizes " + + "the relationship between the variance and mean of the distribution. " + + "Only applicable for the Tweedie family. Supported values: 0 and [1, Inf).", + (x: Double) => x >= 1.0 || x == 0.0) + + /** @group getParam */ + @Since("2.2.0") + def getVariancePower: Double = $(variancePower) + /** * Param for the name of link function which provides the relationship * between the linear predictor and the mean of the distribution function. * Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt". + * This is used only when family is not "tweedie". The link function for the "tweedie" family + * must be specified through [[linkPower]]. * * @group param */ @@ -80,6 +105,21 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam @Since("2.0.0") def getLink: String = $(link) + /** + * Param for the index in the power link function. Only applicable for the Tweedie family. + * Note that link power 0, 1, -1 or 0.5 corresponds to the Log, Identity, Inverse or Sqrt + * link, respectively. + * + * @group param + */ + @Since("2.2.0") + final val linkPower: DoubleParam = new DoubleParam(this, "linkPower", + "The index in the power link function. Only applicable for the Tweedie family.") + + /** @group getParam */ + @Since("2.2.0") + def getLinkPower: Double = $(linkPower) + /** * Param for link prediction (linear predictor) column name. * Default is not set, which means we do not output link prediction. @@ -106,11 +146,27 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam schema: StructType, fitting: Boolean, featuresDataType: DataType): StructType = { - if (isDefined(link)) { - require(supportedFamilyAndLinkPairs.contains( - Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " + - s"with ${$(family)} family does not support ${$(link)} link function.") + if ($(family).toLowerCase == "tweedie") { + if (isSet(link)) { + logWarning("When family is tweedie, use param linkPower to specify link function. " + + "Setting param link will take no effect.") + } + } else { + if (isSet(variancePower)) { + logWarning("When family is not tweedie, setting param variancePower will take no effect.") + } + if (isSet(linkPower)) { + logWarning("When family is not tweedie, use param link to specify link function. " + + "Setting param linkPower will take no effect.") + } + if (isSet(link)) { + require(supportedFamilyAndLinkPairs.contains( + Family.fromParams(this) -> Link.fromParams(this)), + s"Generalized Linear Regression with ${$(family)} family " + + s"does not support ${$(link)} link function.") + } } + val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) if (hasLinkPredictionCol) { SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType) @@ -128,13 +184,15 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam * Generalized linear model (Wikipedia)) * specified by giving a symbolic description of the linear * predictor (link function) and a description of the error distribution (family). - * It supports "gaussian", "binomial", "poisson" and "gamma" as family. + * It supports "gaussian", "binomial", "poisson", "gamma" and "tweedie" as family. * Valid link functions for each family is listed below. The first link function of each family * is the default one. * - "gaussian" : "identity", "log", "inverse" * - "binomial" : "logit", "probit", "cloglog" * - "poisson" : "log", "identity", "sqrt" * - "gamma" : "inverse", "identity", "log" + * - "tweedie" : power link function specified through "linkPower". The default link power in + * the tweedie family is 1 - variancePower. */ @Experimental @Since("2.0.0") @@ -157,8 +215,29 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val def setFamily(value: String): this.type = set(family, value) setDefault(family -> Gaussian.name) + /** + * Sets the value of param [[variancePower]]. + * Used only when family is "tweedie". + * Default is 0.0, which corresponds to the "gaussian" family. + * + * @group setParam + */ + @Since("2.2.0") + def setVariancePower(value: Double): this.type = set(variancePower, value) + setDefault(variancePower -> 0.0) + + /** + * Sets the value of param [[linkPower]]. + * Used only when family is "tweedie". + * + * @group setParam + */ + @Since("2.2.0") + def setLinkPower(value: Double): this.type = set(linkPower, value) + /** * Sets the value of param [[link]]. + * Used only when family is not "tweedie". * * @group setParam */ @@ -242,13 +321,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value) override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = { - val familyObj = Family.fromName($(family)) - val linkObj = if (isDefined(link)) { - Link.fromName($(link)) - } else { - familyObj.defaultLink - } - val familyAndLink = new FamilyAndLink(familyObj, linkObj) + val familyAndLink = FamilyAndLink(this) val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size val instr = Instrumentation.create(this, dataset) @@ -269,7 +342,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val Instance(label, weight, features) } - val model = if (familyObj == Gaussian && linkObj == Identity) { + val model = if (familyAndLink.family == Gaussian && familyAndLink.link == Identity) { // TODO: Make standardizeFeatures and standardizeLabel configurable. val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) @@ -308,7 +381,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine @Since("2.0.0") override def load(path: String): GeneralizedLinearRegression = super.load(path) - /** Set of family and link pairs that GeneralizedLinearRegression supports. */ + /** + * Set of family (except for tweedie) and link pairs that GeneralizedLinearRegression supports. + * The link function of the Tweedie family is specified through param linkPower. + */ private[regression] lazy val supportedFamilyAndLinkPairs = Set( Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse, Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog, @@ -317,10 +393,12 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine ) /** Set of family names that GeneralizedLinearRegression supports. */ - private[regression] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name) + private[regression] lazy val supportedFamilyNames = + supportedFamilyAndLinkPairs.map(_._1.name).toArray :+ "tweedie" /** Set of link names that GeneralizedLinearRegression supports. */ - private[regression] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name) + private[regression] lazy val supportedLinkNames = + supportedFamilyAndLinkPairs.map(_._2.name).toArray private[regression] val epsilon: Double = 1E-16 @@ -369,6 +447,24 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine } } + private[regression] object FamilyAndLink { + + /** + * Constructs the FamilyAndLink object from a parameter map + */ + def apply(params: GeneralizedLinearRegressionBase): FamilyAndLink = { + val familyObj = Family.fromParams(params) + val linkObj = if ((params.getFamily.toLowerCase != "tweedie" && + params.isSet(params.link)) || (params.getFamily.toLowerCase == "tweedie" && + params.isSet(params.linkPower))) { + Link.fromParams(params) + } else { + familyObj.defaultLink + } + new FamilyAndLink(familyObj, linkObj) + } + } + /** * A description of the error distribution to be used in the model. * @@ -409,27 +505,109 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine private[regression] object Family { /** - * Gets the [[Family]] object from its name. + * Gets the [[Family]] object based on param family and variancePower. + * If param family is set with "gaussian", "binomial", "poisson" or "gamma", + * return the corresponding object directly; otherwise, construct a Tweedie object + * according to variancePower. * - * @param name family name: "gaussian", "binomial", "poisson" or "gamma". + * @param params the parameter map containing family name and variance power */ - def fromName(name: String): Family = { - name.toLowerCase match { + def fromParams(params: GeneralizedLinearRegressionBase): Family = { + params.getFamily.toLowerCase match { case Gaussian.name => Gaussian case Binomial.name => Binomial case Poisson.name => Poisson case Gamma.name => Gamma + case "tweedie" => + params.getVariancePower match { + case 0.0 => Gaussian + case 1.0 => Poisson + case 2.0 => Gamma + case others => new Tweedie(others) + } } } } + /** + * Tweedie exponential family distribution. + * This includes the special cases of Gaussian, Poisson and Gamma. + */ + private[regression] class Tweedie(val variancePower: Double) + extends Family("tweedie") { + + override val defaultLink: Link = new Power(1.0 - variancePower) + + override def initialize(y: Double, weight: Double): Double = { + if (variancePower >= 1.0 && variancePower < 2.0) { + require(y >= 0.0, s"The response variable of $name($variancePower) family " + + s"should be non-negative, but got $y") + } else if (variancePower >= 2.0) { + require(y > 0.0, s"The response variable of $name($variancePower) family " + + s"should be positive, but got $y") + } + if (y == 0) Tweedie.delta else y + } + + override def variance(mu: Double): Double = math.pow(mu, variancePower) + + private def yp(y: Double, mu: Double, p: Double): Double = { + if (p == 0) { + math.log(y / mu) + } else { + (math.pow(y, p) - math.pow(mu, p)) / p + } + } + + override def deviance(y: Double, mu: Double, weight: Double): Double = { + // Force y >= delta for Poisson or compound Poisson + val y1 = if (variancePower >= 1.0 && variancePower < 2.0) { + math.max(y, Tweedie.delta) + } else { + y + } + 2.0 * weight * + (y * yp(y1, mu, 1.0 - variancePower) - yp(y, mu, 2.0 - variancePower)) + } + + override def aic( + predictions: RDD[(Double, Double, Double)], + deviance: Double, + numInstances: Double, + weightSum: Double): Double = { + /* + This depends on the density of the Tweedie distribution. + Only implemented for Gaussian, Poisson and Gamma at this point. + */ + throw new UnsupportedOperationException("No AIC available for the tweedie family") + } + + override def project(mu: Double): Double = { + if (mu < epsilon) { + epsilon + } else if (mu.isInfinity) { + Double.MaxValue + } else { + mu + } + } + } + + private[regression] object Tweedie{ + + /** Constant used in initialization and deviance to avoid numerical issues. */ + val delta: Double = 0.1 + } + /** * Gaussian exponential family distribution. * The default link for the Gaussian family is the identity link. */ - private[regression] object Gaussian extends Family("gaussian") { + private[regression] object Gaussian extends Tweedie(0.0) { - val defaultLink: Link = Identity + override val name: String = "gaussian" + + override val defaultLink: Link = Identity override def initialize(y: Double, weight: Double): Double = y @@ -515,18 +693,20 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * Poisson exponential family distribution. * The default link for the Poisson family is the log link. */ - private[regression] object Poisson extends Family("poisson") { + private[regression] object Poisson extends Tweedie(1.0) { + + override val name: String = "poisson" - val defaultLink: Link = Log + override val defaultLink: Link = Log override def initialize(y: Double, weight: Double): Double = { require(y >= 0.0, "The response variable of Poisson family " + s"should be non-negative, but got $y") /* Force Poisson mean > 0 to avoid numerical instability in IRLS. - R uses y + 0.1 for initialization. See poisson()$initialize. + R uses y + delta for initialization. See poisson()$initialize. */ - math.max(y, 0.1) + math.max(y, Tweedie.delta) } override def variance(mu: Double): Double = mu @@ -544,25 +724,17 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine weight * dist.Poisson(mu).logProbabilityOf(y.toInt) }.sum() } - - override def project(mu: Double): Double = { - if (mu < epsilon) { - epsilon - } else if (mu.isInfinity) { - Double.MaxValue - } else { - mu - } - } } /** * Gamma exponential family distribution. * The default link for the Gamma family is the inverse link. */ - private[regression] object Gamma extends Family("gamma") { + private[regression] object Gamma extends Tweedie(2.0) { - val defaultLink: Link = Inverse + override val name: String = "gamma" + + override val defaultLink: Link = Inverse override def initialize(y: Double, weight: Double): Double = { require(y > 0.0, "The response variable of Gamma family " + @@ -586,16 +758,6 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine weight * dist.Gamma(1.0 / disp, mu * disp).logPdf(y) }.sum() + 2.0 } - - override def project(mu: Double): Double = { - if (mu < epsilon) { - epsilon - } else if (mu.isInfinity) { - Double.MaxValue - } else { - mu - } - } } /** @@ -620,25 +782,67 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine private[regression] object Link { /** - * Gets the [[Link]] object from its name. + * Gets the [[Link]] object based on param family, link and linkPower. + * If param family is set with "tweedie", return or construct link function object + * according to linkPower; otherwise, return link function object according to link. * - * @param name link name: "identity", "logit", "log", - * "inverse", "probit", "cloglog" or "sqrt". + * @param params the parameter map containing family, link and linkPower */ - def fromName(name: String): Link = { - name.toLowerCase match { - case Identity.name => Identity - case Logit.name => Logit - case Log.name => Log - case Inverse.name => Inverse - case Probit.name => Probit - case CLogLog.name => CLogLog - case Sqrt.name => Sqrt + def fromParams(params: GeneralizedLinearRegressionBase): Link = { + if (params.getFamily.toLowerCase == "tweedie") { + params.getLinkPower match { + case 0.0 => Log + case 1.0 => Identity + case -1.0 => Inverse + case 0.5 => Sqrt + case others => new Power(others) + } + } else { + params.getLink.toLowerCase match { + case Identity.name => Identity + case Logit.name => Logit + case Log.name => Log + case Inverse.name => Inverse + case Probit.name => Probit + case CLogLog.name => CLogLog + case Sqrt.name => Sqrt + } } } } - private[regression] object Identity extends Link("identity") { + /** Power link function class */ + private[regression] class Power(val linkPower: Double) + extends Link("power") { + + override def link(mu: Double): Double = { + if (linkPower == 0.0) { + math.log(mu) + } else { + math.pow(mu, linkPower) + } + } + + override def deriv(mu: Double): Double = { + if (linkPower == 0.0) { + 1.0 / mu + } else { + linkPower * math.pow(mu, linkPower - 1.0) + } + } + + override def unlink(eta: Double): Double = { + if (linkPower == 0.0) { + math.exp(eta) + } else { + math.pow(eta, 1.0 / linkPower) + } + } + } + + private[regression] object Identity extends Power(1.0) { + + override val name: String = "identity" override def link(mu: Double): Double = mu @@ -656,7 +860,9 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-1.0 * eta)) } - private[regression] object Log extends Link("log") { + private[regression] object Log extends Power(0.0) { + + override val name: String = "log" override def link(mu: Double): Double = math.log(mu) @@ -665,7 +871,9 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def unlink(eta: Double): Double = math.exp(eta) } - private[regression] object Inverse extends Link("inverse") { + private[regression] object Inverse extends Power(-1.0) { + + override val name: String = "inverse" override def link(mu: Double): Double = 1.0 / mu @@ -694,7 +902,9 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def unlink(eta: Double): Double = 1.0 - math.exp(-1.0 * math.exp(eta)) } - private[regression] object Sqrt extends Link("sqrt") { + private[regression] object Sqrt extends Power(0.5) { + + override val name: String = "sqrt" override def link(mu: Double): Double = math.sqrt(mu) @@ -727,13 +937,7 @@ class GeneralizedLinearRegressionModel private[ml] ( import GeneralizedLinearRegression._ - private lazy val familyObj = Family.fromName($(family)) - private lazy val linkObj = if (isDefined(link)) { - Link.fromName($(link)) - } else { - familyObj.defaultLink - } - private lazy val familyAndLink = new FamilyAndLink(familyObj, linkObj) + private lazy val familyAndLink = FamilyAndLink(this) override protected def predict(features: Vector): Double = { val eta = predictLink(features) @@ -912,12 +1116,11 @@ class GeneralizedLinearRegressionSummary private[regression] ( */ @Since("2.0.0") @transient val predictions: DataFrame = model.transform(dataset) - private[regression] lazy val family: Family = Family.fromName(model.getFamily) - private[regression] lazy val link: Link = if (model.isDefined(model.link)) { - Link.fromName(model.getLink) - } else { - family.defaultLink - } + private[regression] lazy val familyLink: FamilyAndLink = FamilyAndLink(model) + + private[regression] lazy val family: Family = familyLink.family + + private[regression] lazy val link: Link = familyLink.link /** Number of instances in DataFrame predictions. */ private[regression] lazy val numInstances: Long = predictions.count() diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 828b95e544ae8..ea059858a58b7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -191,6 +191,8 @@ class GeneralizedLinearRegressionSuite assert(!glr.isDefined(glr.weightCol)) assert(glr.getRegParam === 0.0) assert(glr.getSolver == "irls") + assert(glr.getVariancePower === 0.0) + // TODO: Construct model directly instead of via fitting. val model = glr.setFamily("gaussian").setLink("identity") .fit(datasetGaussianIdentity) @@ -266,7 +268,7 @@ class GeneralizedLinearRegressionSuite assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " + s"$link link and fitIntercept = $fitIntercept.") - val familyLink = new FamilyAndLink(Gaussian, Link.fromName(link)) + val familyLink = FamilyAndLink(trainer) model.transform(dataset).select("features", "prediction", "linkPrediction").collect() .foreach { case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) => @@ -382,7 +384,7 @@ class GeneralizedLinearRegressionSuite assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with binomial family, " + s"$link link and fitIntercept = $fitIntercept.") - val familyLink = new FamilyAndLink(Binomial, Link.fromName(link)) + val familyLink = FamilyAndLink(trainer) model.transform(dataset).select("features", "prediction", "linkPrediction").collect() .foreach { case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) => @@ -454,7 +456,7 @@ class GeneralizedLinearRegressionSuite assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " + s"$link link and fitIntercept = $fitIntercept.") - val familyLink = new FamilyAndLink(Poisson, Link.fromName(link)) + val familyLink = FamilyAndLink(trainer) model.transform(dataset).select("features", "prediction", "linkPrediction").collect() .foreach { case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) => @@ -560,7 +562,7 @@ class GeneralizedLinearRegressionSuite assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gamma family, " + s"$link link and fitIntercept = $fitIntercept.") - val familyLink = new FamilyAndLink(Gamma, Link.fromName(link)) + val familyLink = FamilyAndLink(trainer) model.transform(dataset).select("features", "prediction", "linkPrediction").collect() .foreach { case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) => @@ -578,6 +580,169 @@ class GeneralizedLinearRegressionSuite } } + test("generalized linear regression: tweedie family against glm") { + /* + R code: + library(statmod) + df <- as.data.frame(matrix(c( + 1.0, 1.0, 0.0, 5.0, + 0.5, 1.0, 1.0, 2.0, + 1.0, 1.0, 2.0, 1.0, + 2.0, 1.0, 3.0, 3.0), 4, 4, byrow = TRUE)) + + f1 <- V1 ~ -1 + V3 + V4 + f2 <- V1 ~ V3 + V4 + + for (f in c(f1, f2)) { + for (lp in c(0, 1, -1)) + for (vp in c(1.6, 2.5)) { + model <- glm(f, df, family = tweedie(var.power = vp, link.power = lp)) + print(as.vector(coef(model))) + } + } + [1] 0.1496480 -0.0122283 + [1] 0.1373567 -0.0120673 + [1] 0.3919109 0.1846094 + [1] 0.3684426 0.1810662 + [1] 0.1759887 0.2195818 + [1] 0.1108561 0.2059430 + [1] -1.3163732 0.4378139 0.2464114 + [1] -1.4396020 0.4817364 0.2680088 + [1] -0.7090230 0.6256309 0.3294324 + [1] -0.9524928 0.7304267 0.3792687 + [1] 2.1188978 -0.3360519 -0.2067023 + [1] 2.1659028 -0.3499170 -0.2128286 + */ + val datasetTweedie = Seq( + Instance(1.0, 1.0, Vectors.dense(0.0, 5.0)), + Instance(0.5, 1.0, Vectors.dense(1.0, 2.0)), + Instance(1.0, 1.0, Vectors.dense(2.0, 1.0)), + Instance(2.0, 1.0, Vectors.dense(3.0, 3.0)) + ).toDF() + + val expected = Seq( + Vectors.dense(0, 0.149648, -0.0122283), + Vectors.dense(0, 0.1373567, -0.0120673), + Vectors.dense(0, 0.3919109, 0.1846094), + Vectors.dense(0, 0.3684426, 0.1810662), + Vectors.dense(0, 0.1759887, 0.2195818), + Vectors.dense(0, 0.1108561, 0.205943), + Vectors.dense(-1.3163732, 0.4378139, 0.2464114), + Vectors.dense(-1.439602, 0.4817364, 0.2680088), + Vectors.dense(-0.709023, 0.6256309, 0.3294324), + Vectors.dense(-0.9524928, 0.7304267, 0.3792687), + Vectors.dense(2.1188978, -0.3360519, -0.2067023), + Vectors.dense(2.1659028, -0.349917, -0.2128286)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for (fitIntercept <- Seq(false, true); + linkPower <- Seq(0.0, 1.0, -1.0); + variancePower <- Seq(1.6, 2.5)) { + val trainer = new GeneralizedLinearRegression().setFamily("tweedie") + .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction") + .setVariancePower(variancePower).setLinkPower(linkPower) + val model = trainer.fit(datasetTweedie) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with tweedie family, " + + s"linkPower = $linkPower, fitIntercept = $fitIntercept " + + s"and variancePower = $variancePower.") + + val familyLink = FamilyAndLink(trainer) + model.transform(datasetTweedie).select("features", "prediction", "linkPrediction").collect() + .foreach { + case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + val linkPrediction2 = eta + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"tweedie family, linkPower = $linkPower, fitIntercept = $fitIntercept " + + s"and variancePower = $variancePower.") + assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " + + s"GLM with tweedie family, linkPower = $linkPower, fitIntercept = $fitIntercept " + + s"and variancePower = $variancePower.") + } + idx += 1 + } + } + + test("generalized linear regression: tweedie family against glm (default power link)") { + /* + R code: + library(statmod) + df <- as.data.frame(matrix(c( + 1.0, 1.0, 0.0, 5.0, + 0.5, 1.0, 1.0, 2.0, + 1.0, 1.0, 2.0, 1.0, + 2.0, 1.0, 3.0, 3.0), 4, 4, byrow = TRUE)) + var.power <- c(0, 1, 2, 1.5) + f1 <- V1 ~ -1 + V3 + V4 + f2 <- V1 ~ V3 + V4 + for (f in c(f1, f2)) { + for (vp in var.power) { + model <- glm(f, df, family = tweedie(var.power = vp)) + print(as.vector(coef(model))) + } + } + [1] 0.4310345 0.1896552 + [1] 0.15776482 -0.01189032 + [1] 0.1468853 0.2116519 + [1] 0.2282601 0.2132775 + [1] -0.5158730 0.5555556 0.2936508 + [1] -1.2689559 0.4230934 0.2388465 + [1] 2.137852 -0.341431 -0.209090 + [1] 1.5953393 -0.1884985 -0.1106335 + */ + val datasetTweedie = Seq( + Instance(1.0, 1.0, Vectors.dense(0.0, 5.0)), + Instance(0.5, 1.0, Vectors.dense(1.0, 2.0)), + Instance(1.0, 1.0, Vectors.dense(2.0, 1.0)), + Instance(2.0, 1.0, Vectors.dense(3.0, 3.0)) + ).toDF() + + val expected = Seq( + Vectors.dense(0, 0.4310345, 0.1896552), + Vectors.dense(0, 0.15776482, -0.01189032), + Vectors.dense(0, 0.1468853, 0.2116519), + Vectors.dense(0, 0.2282601, 0.2132775), + Vectors.dense(-0.515873, 0.5555556, 0.2936508), + Vectors.dense(-1.2689559, 0.4230934, 0.2388465), + Vectors.dense(2.137852, -0.341431, -0.20909), + Vectors.dense(1.5953393, -0.1884985, -0.1106335)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + for (variancePower <- Seq(0.0, 1.0, 2.0, 1.5)) { + val trainer = new GeneralizedLinearRegression().setFamily("tweedie") + .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction") + .setVariancePower(variancePower) + val model = trainer.fit(datasetTweedie) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with tweedie family, " + + s"fitIntercept = $fitIntercept and variancePower = $variancePower.") + + val familyLink = FamilyAndLink(trainer) + model.transform(datasetTweedie).select("features", "prediction", "linkPrediction").collect() + .foreach { + case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + val linkPrediction2 = eta + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"tweedie family, fitIntercept = $fitIntercept " + + s"and variancePower = $variancePower.") + assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " + + s"GLM with tweedie family, fitIntercept = $fitIntercept " + + s"and variancePower = $variancePower.") + } + idx += 1 + } + } + } + test("glm summary: gaussian family with weight") { /* R code: @@ -1052,6 +1217,121 @@ class GeneralizedLinearRegressionSuite assert(summary.solver === "irls") } + test("glm summary: tweedie family with weight") { + /* + R code: + + library(statmod) + df <- as.data.frame(matrix(c( + 1.0, 1.0, 0.0, 5.0, + 0.5, 2.0, 1.0, 2.0, + 1.0, 3.0, 2.0, 1.0, + 0.0, 4.0, 3.0, 3.0), 4, 4, byrow = TRUE)) + + model <- glm(V1 ~ -1 + V3 + V4, data = df, weights = V2, + family = tweedie(var.power = 1.6, link.power = 0)) + summary(model) + + Deviance Residuals: + 1 2 3 4 + 0.6210 -0.0515 1.6935 -3.2539 + + Coefficients: + Estimate Std. Error t value Pr(>|t|) + V3 -0.4087 0.5205 -0.785 0.515 + V4 -0.1212 0.4082 -0.297 0.794 + + (Dispersion parameter for Tweedie family taken to be 3.830036) + + Null deviance: 20.702 on 4 degrees of freedom + Residual deviance: 13.844 on 2 degrees of freedom + AIC: NA + + Number of Fisher Scoring iterations: 11 + + residuals(model, type="pearson") + 1 2 3 4 + 0.7383616 -0.0509458 2.2348337 -1.4552090 + residuals(model, type="working") + 1 2 3 4 + 0.83354150 -0.04103552 1.55676369 -1.00000000 + residuals(model, type="response") + 1 2 3 4 + 0.45460738 -0.02139574 0.60888055 -0.20392801 + */ + val datasetWithWeight = Seq( + Instance(1.0, 1.0, Vectors.dense(0.0, 5.0)), + Instance(0.5, 2.0, Vectors.dense(1.0, 2.0)), + Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)), + Instance(0.0, 4.0, Vectors.dense(3.0, 3.0)) + ).toDF() + + val trainer = new GeneralizedLinearRegression() + .setFamily("tweedie") + .setVariancePower(1.6) + .setLinkPower(0.0) + .setWeightCol("weight") + .setFitIntercept(false) + + val model = trainer.fit(datasetWithWeight) + val coefficientsR = Vectors.dense(Array(-0.408746, -0.12125)) + val interceptR = 0.0 + val devianceResidualsR = Array(0.621047, -0.051515, 1.693473, -3.253946) + val pearsonResidualsR = Array(0.738362, -0.050946, 2.234834, -1.455209) + val workingResidualsR = Array(0.833541, -0.041036, 1.556764, -1.0) + val responseResidualsR = Array(0.454607, -0.021396, 0.608881, -0.203928) + val seCoefR = Array(0.520519, 0.408215) + val tValsR = Array(-0.785267, -0.297024) + val pValsR = Array(0.514549, 0.794457) + val dispersionR = 3.830036 + val nullDevianceR = 20.702 + val residualDevianceR = 13.844 + val residualDegreeOfFreedomNullR = 4 + val residualDegreeOfFreedomR = 2 + + val summary = model.summary + + val devianceResiduals = summary.residuals() + .select(col("devianceResiduals")) + .collect() + .map(_.getDouble(0)) + val pearsonResiduals = summary.residuals("pearson") + .select(col("pearsonResiduals")) + .collect() + .map(_.getDouble(0)) + val workingResiduals = summary.residuals("working") + .select(col("workingResiduals")) + .collect() + .map(_.getDouble(0)) + val responseResiduals = summary.residuals("response") + .select(col("responseResiduals")) + .collect() + .map(_.getDouble(0)) + + assert(model.coefficients ~== coefficientsR absTol 1E-3) + assert(model.intercept ~== interceptR absTol 1E-3) + devianceResiduals.zip(devianceResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + pearsonResiduals.zip(pearsonResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + workingResiduals.zip(workingResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + responseResiduals.zip(responseResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + + summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => + assert(x._1 ~== x._2 absTol 1E-3) } + summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + + assert(summary.dispersion ~== dispersionR absTol 1E-3) + assert(summary.nullDeviance ~== nullDevianceR absTol 1E-3) + assert(summary.deviance ~== residualDevianceR absTol 1E-3) + assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) + assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) + assert(summary.solver === "irls") + } + test("glm handle collinear features") { val collinearInstances = Seq( Instance(1.0, 1.0, Vectors.dense(1.0, 2.0)), @@ -1183,7 +1463,8 @@ object GeneralizedLinearRegressionSuite { "maxIter" -> 2, // intentionally small "tol" -> 0.8, "regParam" -> 0.01, - "predictionCol" -> "myPrediction") + "predictionCol" -> "myPrediction", + "variancePower" -> 1.0) def generateGeneralizedLinearRegressionInput( intercept: Double,