Fix some typos and calculation of initial weights
mateiz committed Apr 15, 2014
1 parent 74eefe7 commit a07ba10
Showing 2 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions python/pyspark/mllib/_common.py
@@ -329,9 +329,9 @@ def _get_initial_weights(initial_weights, data):
             if initial_weights.ndim != 1:
                 raise TypeError("At least one data element has "
                                 + initial_weights.ndim + " dimensions, which is not 1")
-            initial_weights = numpy.ones([initial_weights.shape[0] - 1])
+            initial_weights = numpy.ones([initial_weights.shape[0]])
         elif type(initial_weights) == SparseVector:
-            initial_weights = numpy.ones([initial_weights.size - 1])
+            initial_weights = numpy.ones([initial_weights.size])
     return initial_weights


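For orientation, here is a sketch of how the corrected helper reads after this change. Only the lines shown in the hunk above come from the diff; the outer `initial_weights is None` guard and the `data.first().features` lookup are assumptions about the enclosing function, not something this commit shows. The point of the fix is that a LabeledPoint's feature vector no longer carries the label, so the initial weight vector must match its full length instead of dropping one element.

    # Sketch only: the helper as it would read after this commit.
    # Lines outside the diff hunk are inferred for illustration.
    import numpy
    from numpy import ndarray

    from pyspark.mllib.linalg import SparseVector


    def _get_initial_weights(initial_weights, data):
        if initial_weights is None:
            # Assumed context: take the shape from the first point's feature vector.
            initial_weights = data.first().features
            if type(initial_weights) == ndarray:
                if initial_weights.ndim != 1:
                    raise TypeError("At least one data element has "
                                    + initial_weights.ndim + " dimensions, which is not 1")
                # The fix: weights span the full feature length. The old "- 1"
                # dated from when the label was packed into the same vector.
                initial_weights = numpy.ones([initial_weights.shape[0]])
            elif type(initial_weights) == SparseVector:
                initial_weights = numpy.ones([initial_weights.size])
        return initial_weights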
10 changes: 5 additions & 5 deletions python/pyspark/mllib/classification.py
@@ -24,7 +24,7 @@
     _serialize_double_matrix, _deserialize_double_matrix, \
     _serialize_double_vector, _deserialize_double_vector, \
     _get_initial_weights, _serialize_rating, _regression_train_wrapper, \
-    LinearModel, _linear_predictor_typecheck
+    LinearModel, _linear_predictor_typecheck, _get_unmangled_labeled_point_rdd
 from pyspark.mllib.linalg import SparseVector
 from pyspark.mllib.regression import LabeledPoint
 from math import exp, log
@@ -135,9 +135,9 @@ class NaiveBayesModel(object):
     >>> model.predict(array([1.0, 0.0]))
     1.0
     >>> sparse_data = [
-    ...     LabeledPoint(0.0, SparseVector(2, {1: 0.0}),
-    ...     LabeledPoint(0.0, SparseVector(2, {1: 1.0}),
-    ...     LabeledPoint(1.0, SparseVector(2, {0: 1.0})
+    ...     LabeledPoint(0.0, SparseVector(2, {1: 0.0})),
+    ...     LabeledPoint(0.0, SparseVector(2, {1: 1.0})),
+    ...     LabeledPoint(1.0, SparseVector(2, {0: 1.0}))
     ... ]
     >>> model = NaiveBayes.train(sc.parallelize(sparse_data))
     >>> model.predict(SparseVector(2, {1: 1.0}))
@@ -173,7 +173,7 @@ def train(cls, data, lambda_=1.0):
         @param lambda_: The smoothing parameter
         """
         sc = data.context
-        dataBytes = _get_unmangled_double_vector_rdd(data)
+        dataBytes = _get_unmangled_labeled_point_rdd(data)
         ans = sc._jvm.PythonMLLibAPI().trainNaiveBayes(dataBytes._jrdd, lambda_)
         return NaiveBayesModel(
             _deserialize_double_vector(ans[0]),
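Since NaiveBayes now trains on an RDD of LabeledPoint rather than plain double vectors, the training data is serialized with `_get_unmangled_labeled_point_rdd`. Below is a short usage sketch mirroring the corrected doctest; it assumes a running SparkContext named `sc`, and the expected prediction is inferred since the doctest output is truncated in the diff above.

    # Usage sketch mirroring the corrected doctest; assumes a live SparkContext `sc`.
    from pyspark.mllib.classification import NaiveBayes
    from pyspark.mllib.linalg import SparseVector
    from pyspark.mllib.regression import LabeledPoint

    sparse_data = [
        LabeledPoint(0.0, SparseVector(2, {1: 0.0})),
        LabeledPoint(0.0, SparseVector(2, {1: 1.0})),
        LabeledPoint(1.0, SparseVector(2, {0: 1.0})),
    ]
    model = NaiveBayes.train(sc.parallelize(sparse_data))
    # Feature 1 is associated with class 0.0 in this data, so 0.0 is the
    # expected prediction here (the doctest's output is cut off above).
    print(model.predict(SparseVector(2, {1: 1.0})))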
