Skip to content

Commit

Permalink
add more comments to the example
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Jan 27, 2015
1 parent 5153cff commit 0586c7b
Showing 1 changed file with 8 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
if __name__ == "__main__":
sc = SparkContext(appName="SimpleTextClassificationPipeline")
sqlCtx = SQLContext(sc)

# Prepare training documents, which are labeled.
LabeledDocument = Row('id', 'text', 'label')
training = sqlCtx.inferSchema(
sc.parallelize([(0L, "a b c d e spark", 1.0),
Expand All @@ -42,6 +44,7 @@
(3L, "hadoop mapreduce", 0.0)])
.map(lambda x: LabeledDocument(*x)))

# Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
tokenizer = Tokenizer() \
.setInputCol("text") \
.setOutputCol("words")
Expand All @@ -54,8 +57,10 @@
pipeline = Pipeline() \
.setStages([tokenizer, hashingTF, lr])

# Fit the pipeline to training documents.
model = pipeline.fit(training)

# Prepare test documents, which are unlabeled.
Document = Row('id', 'text')
test = sqlCtx.inferSchema(
sc.parallelize([(4L, "spark i j k"),
Expand All @@ -64,9 +69,11 @@
(7L, "apache hadoop")])
.map(lambda x: Document(*x)))

# Make predictions on test documents and print columns of interest.
prediction = model.transform(test)

prediction.registerTempTable("prediction")
selected = sqlCtx.sql("SELECT id, text, prediction from prediction")
for row in selected.collect():
print row

sc.stop()

0 comments on commit 0586c7b

Please sign in to comment.