From ae7a704d502ca841313a29359ba7766c00841bc9 Mon Sep 17 00:00:00 2001 From: Navin Singh Date: Fri, 31 Dec 2021 21:50:49 +0530 Subject: [PATCH] new phase updateLabel implementation #58 --- .../main/java/zingg/client/ZinggOptions.java | 3 +- core/src/main/java/zingg/LabelUpdater.java | 129 ++++++++++++++++++ core/src/main/java/zingg/Labeller.java | 22 +-- core/src/main/java/zingg/ZFactory.java | 1 + 4 files changed, 145 insertions(+), 10 deletions(-) create mode 100644 core/src/main/java/zingg/LabelUpdater.java diff --git a/client/src/main/java/zingg/client/ZinggOptions.java b/client/src/main/java/zingg/client/ZinggOptions.java index 75ab1825f..c42023825 100644 --- a/client/src/main/java/zingg/client/ZinggOptions.java +++ b/client/src/main/java/zingg/client/ZinggOptions.java @@ -12,7 +12,8 @@ public enum ZinggOptions { FIND_TRAINING_DATA("findTrainingData"), LABEL("label"), LINK("link"), - GENERATE_DOCS("generateDocs"); + GENERATE_DOCS("generateDocs"), + UPDATE_LABEL("updateLabel"); private String value; diff --git a/core/src/main/java/zingg/LabelUpdater.java b/core/src/main/java/zingg/LabelUpdater.java new file mode 100644 index 000000000..1773db978 --- /dev/null +++ b/core/src/main/java/zingg/LabelUpdater.java @@ -0,0 +1,129 @@ +package zingg; + +import java.util.List; +import java.util.Scanner; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; + +import zingg.client.ZinggClientException; +import zingg.client.ZinggOptions; +import zingg.client.pipe.Pipe; +import zingg.client.util.ColName; +import zingg.client.util.ColValues; +import zingg.util.DSUtil; +import zingg.util.PipeUtil; + +public class LabelUpdater extends Labeller { + protected static String name = "zingg.LabelUpdater"; + public static final Log LOG = LogFactory.getLog(LabelUpdater.class); + + public LabelUpdater() { + setZinggOptions(ZinggOptions.UPDATE_LABEL); + } + + public void execute() throws ZinggClientException { + try { + LOG.info("Reading inputs for updateLabelling phase ..."); + Dataset markedRecords = PipeUtil.read(spark, false, false, PipeUtil.getTrainingDataMarkedPipe(args)); + processRecordsCli(markedRecords); + LOG.info("Finished updataLabelling phase"); + } catch (Exception e) { + e.printStackTrace(); + throw new ZinggClientException(e.getMessage()); + } + } + + public void processRecordsCli(Dataset lines) throws ZinggClientException { + LOG.info("Processing Records for CLI updateLabelling"); + getMarkedRecordsStat(lines); + printMarkedRecordsStat(); + if (lines == null || lines.count() == 0) { + LOG.info("There is no marked record for updating. Please run findTrainingData/label jobs to generate training data."); + return; + } + + List displayCols = DSUtil.getFieldDefColumns(lines, args, false); + try { + int matchFlag; + Dataset updatedRecords = null; + Dataset recordsToUpdate = lines; + int selectedOption = -1; + String postMsg; + + Scanner sc = new Scanner(System.in); + do { + System.out.print("\tPlease enter the cluster id (or '9' to exit): "); + String cluster_id = sc.next(); + if (cluster_id.equals("9")) { + LOG.info("User has exit in the middle. Updating the records."); + break; + } + Dataset currentPair = lines.filter(lines.col(ColName.CLUSTER_COLUMN).equalTo(cluster_id)); + if (currentPair.isEmpty()) { + System.out.println("\tInvalid cluster id. Enter '9' to exit"); + continue; + } + + matchFlag = currentPair.head().getAs(ColName.MATCH_FLAG_COL); + postMsg = String.format("\tCurrent Match type for the above pair is %d\n", matchFlag); + selectedOption = displayRecordsAndGetUserInput(DSUtil.select(currentPair, displayCols), "", postMsg); + updateLabellerStat(selectedOption, matchFlag); + printMarkedRecordsStat(); + if (selectedOption == 9) { + LOG.info("User has quit in the middle. Updating the records."); + break; + } + recordsToUpdate = recordsToUpdate + .filter(recordsToUpdate.col(ColName.CLUSTER_COLUMN).notEqual(cluster_id)); + if (updatedRecords != null) { + updatedRecords = updatedRecords + .filter(updatedRecords.col(ColName.CLUSTER_COLUMN).notEqual(cluster_id)); + } + updatedRecords = updateRecords(selectedOption, currentPair, updatedRecords); + } while (selectedOption != 9); + + if (updatedRecords != null) { + updatedRecords = updatedRecords.union(recordsToUpdate); + } + writeLabelledOutput(updatedRecords, SaveMode.Overwrite); + sc.close(); + LOG.info("Processing finished."); + } catch (Exception e) { + if (LOG.isDebugEnabled()) { + e.printStackTrace(); + } + LOG.warn("An error has occured while Updating Label. " + e.getMessage()); + throw new ZinggClientException(e.getMessage()); + } + return; + } + + private void updateLabellerStat(int selectedOption, int existingType) { + --totalCount; + if (existingType == ColValues.MATCH_TYPE_MATCH) { + --positivePairsCount; + } + else if (existingType == ColValues.MATCH_TYPE_NOT_A_MATCH) { + --negativePairsCount; + } + else if (existingType == ColValues.MATCH_TYPE_NOT_SURE) { + --notSurePairsCount; + } + updateLabellerStat(selectedOption); + } + + void writeLabelledOutput(Dataset records, SaveMode mode) { + if (records == null) { + LOG.warn("No marked record has been updated."); + return; + } + Pipe p = PipeUtil.getTrainingDataMarkedPipe(args); + p.setMode(mode); + PipeUtil.write(records, args, ctx, p); + } +} \ No newline at end of file diff --git a/core/src/main/java/zingg/Labeller.java b/core/src/main/java/zingg/Labeller.java index aee15c750..79dd7769f 100644 --- a/core/src/main/java/zingg/Labeller.java +++ b/core/src/main/java/zingg/Labeller.java @@ -59,10 +59,7 @@ public Dataset getUnmarkedRecords() throws ZinggClientException { unmarkedRecords = unmarkedRecords.join(markedRecords, unmarkedRecords.col(ColName.CLUSTER_COLUMN).equalTo(markedRecords.col(ColName.CLUSTER_COLUMN)), "left_anti"); - positivePairsCount = markedRecords.filter(markedRecords.col(ColName.MATCH_FLAG_COL).equalTo(ColValues.MATCH_TYPE_MATCH)).count() / 2; - negativePairsCount = markedRecords.filter(markedRecords.col(ColName.MATCH_FLAG_COL).equalTo(ColValues.MATCH_TYPE_NOT_A_MATCH)).count() / 2; - notSurePairsCount = markedRecords.filter(markedRecords.col(ColName.MATCH_FLAG_COL).equalTo(ColValues.MATCH_TYPE_NOT_SURE)).count() / 2; - totalCount = markedRecords.count() / 2; + getMarkedRecordsStat(markedRecords); } } catch (Exception e) { LOG.warn("No unmarked record for labelling"); @@ -70,6 +67,13 @@ public Dataset getUnmarkedRecords() throws ZinggClientException { return unmarkedRecords; } + protected void getMarkedRecordsStat(Dataset markedRecords) { + positivePairsCount = markedRecords.filter(markedRecords.col(ColName.MATCH_FLAG_COL).equalTo(ColValues.MATCH_TYPE_MATCH)).count() / 2; + negativePairsCount = markedRecords.filter(markedRecords.col(ColName.MATCH_FLAG_COL).equalTo(ColValues.MATCH_TYPE_NOT_A_MATCH)).count() / 2; + notSurePairsCount = markedRecords.filter(markedRecords.col(ColName.MATCH_FLAG_COL).equalTo(ColValues.MATCH_TYPE_NOT_SURE)).count() / 2; + totalCount = markedRecords.count() / 2; + } + public void processRecordsCli(Dataset lines) throws ZinggClientException { LOG.info("Processing Records for CLI Labelling"); printMarkedRecordsStat(); @@ -105,6 +109,7 @@ public void processRecordsCli(Dataset lines) throws ZinggClientException { selected_option = displayRecordsAndGetUserInput(DSUtil.select(currentPair, displayCols), msg1, msg2); updateLabellerStat(selected_option); + printMarkedRecordsStat(); if (selected_option == 9) { LOG.info("User has quit in the middle. Updating the records."); break; @@ -124,7 +129,7 @@ public void processRecordsCli(Dataset lines) throws ZinggClientException { } - private int displayRecordsAndGetUserInput(Dataset records, String preMessage, String postMessage) { + protected int displayRecordsAndGetUserInput(Dataset records, String preMessage, String postMessage) { //System.out.println(); System.out.println(preMessage); records.show(false); @@ -134,7 +139,7 @@ private int displayRecordsAndGetUserInput(Dataset records, String preMessag return selection; } - private Dataset updateRecords(int matchValue, Dataset newRecords, Dataset updatedRecords) { + protected Dataset updateRecords(int matchValue, Dataset newRecords, Dataset updatedRecords) { newRecords = newRecords.withColumn(ColName.MATCH_FLAG_COL, functions.lit(matchValue)); if (updatedRecords == null) { updatedRecords = newRecords; @@ -200,7 +205,7 @@ int readCliInput() { return selection; } - private void updateLabellerStat(int selected_option) { + protected void updateLabellerStat(int selected_option) { ++totalCount; if (selected_option == ColValues.MATCH_TYPE_MATCH) { ++positivePairsCount; @@ -211,10 +216,9 @@ else if (selected_option == ColValues.MATCH_TYPE_NOT_A_MATCH) { else if (selected_option == ColValues.MATCH_TYPE_NOT_SURE) { ++notSurePairsCount; } - printMarkedRecordsStat(); } - private void printMarkedRecordsStat() { + protected void printMarkedRecordsStat() { String msg = String.format( "\tLabelled pairs so far : %d/%d MATCH, %d/%d DO NOT MATCH, %d/%d NOT SURE", positivePairsCount, totalCount, negativePairsCount, totalCount, notSurePairsCount, totalCount); diff --git a/core/src/main/java/zingg/ZFactory.java b/core/src/main/java/zingg/ZFactory.java index 38f1dbe14..2a76c6d9b 100644 --- a/core/src/main/java/zingg/ZFactory.java +++ b/core/src/main/java/zingg/ZFactory.java @@ -20,6 +20,7 @@ public ZFactory() {} zinggers.put(ZinggOptions.TRAIN_MATCH, TrainMatcher.name); zinggers.put(ZinggOptions.LINK, Linker.name); zinggers.put(ZinggOptions.GENERATE_DOCS, Documenter.name); + zinggers.put(ZinggOptions.UPDATE_LABEL, LabelUpdater.name); } public IZingg get(ZinggOptions z) throws InstantiationException, IllegalAccessException, ClassNotFoundException {