diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 4759f5fe783ad..1731253df42a8 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -425,7 +425,7 @@ def distinct(self): def sample(self, withReplacement, fraction, seed=None): """Returns a sampled subset of this :class:`DataFrame`. - >>> df.sample(False, 0.5, 97).count() + >>> df.sample(False, 0.5, 42).count() 1 """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction @@ -433,6 +433,22 @@ def sample(self, withReplacement, fraction, seed=None): rdd = self._jdf.sample(withReplacement, fraction, long(seed)) return DataFrame(rdd, self.sql_ctx) + def randomSplit(self, weights, seed=None): + """Randomly splits this :class:`DataFrame` with the provided weights. + + >>> splits = df4.randomSplit([1.0, 2.0], 24) + >>> splits[0].count() + 1 + + >>> splits[1].count() + 3 + """ + for w in weights: + assert w >= 0.0, "Negative weight value: %s" % w + seed = seed if seed is not None else random.randint(0, sys.maxsize) + rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed)) + return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array] + @property def dtypes(self): """Returns all column names and their data types as a list. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 47cb0cbd586e4..74512ab48047d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -728,7 +728,7 @@ class DataFrame private[sql]( /** * Randomly splits this [[DataFrame]] with the provided weights. * - * @param weights weights for splits, will be normalized if they don't sum to 1 + * @param weights weights for splits, will be normalized if they don't sum to 1. * @param seed Seed for sampling. * @group dfops */ @@ -743,13 +743,24 @@ class DataFrame private[sql]( /** * Randomly splits this [[DataFrame]] with the provided weights. * - * @param weights weights for splits, will be normalized if they don't sum to 1 + * @param weights weights for splits, will be normalized if they don't sum to 1. * @group dfops */ def randomSplit(weights: Array[Double]): Array[DataFrame] = { randomSplit(weights, Utils.random.nextLong) } + /** + * Randomly splits this [[DataFrame]] with the provided weights. Provided for the Python Api. + * + * @param weights weights for splits, will be normalized if they don't sum to 1. + * @param seed Seed for sampling. + * @group dfops + */ + def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = { + randomSplit(weights.toArray, seed) + } + /** * (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of