
Merge pull request #313 from navinrathore/PythonPhases
Updates in Python classes and 'assessModel' python phase
sonalgoyal authored Jun 9, 2022
2 parents 9498e97 + e802117 commit 031ed56
Showing 11 changed files with 305 additions and 58 deletions.
28 changes: 27 additions & 1 deletion client/src/main/java/zingg/client/Arguments.java
@@ -18,7 +18,9 @@
import com.fasterxml.jackson.annotation.JsonSetter;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.module.scala.DefaultScalaModule;
import com.fasterxml.jackson.core.json.JsonWriteFeature;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -171,10 +173,33 @@ public static final Arguments createArgumentsFromJSON(String filePath, String ph
} catch (Exception e) {
//e.printStackTrace();
throw new ZinggClientException("Unable to parse the configuration at " + filePath +
". The error is " + e.getMessage());
". The error is " + e.getMessage(), e);
}
}

/**
* Write arguments to a JSON file
*
* @param filePath
* json file where arguments shall be written to
* @param args
* the Arguments object to be written out
* @throws ZinggClientException
* in case there is an error in writing to the file
*/
public static final void writeArgumentsToJSON(String filePath, Arguments args)
throws ZinggClientException {
try {
ObjectMapper mapper = new ObjectMapper();
mapper.enable(SerializationFeature.INDENT_OUTPUT);
mapper.getFactory().configure(JsonWriteFeature.QUOTE_FIELD_NAMES.mappedFeature(),true);
LOG.warn("Arguments are written to file: " + filePath);
mapper.writeValue(new File(filePath), args);
} catch (Exception e) {
throw new ZinggClientException("Unable to write the configuration to " + filePath +
". The error is " + e.getMessage(), e);
}
}

public static void checkValid(Arguments args, String phase) throws ZinggClientException {
if (phase.equals("train") || phase.equals("match") || phase.equals("trainMatch") || phase.equals("link")) {
checkIsValid(args);
@@ -635,6 +660,7 @@ public void setBlockSize(long blockSize){
this.blockSize = blockSize;
}

@JsonIgnore
public String[] getPipeNames() {
Pipe[] input = this.getData();
String[] sourceNames = new String[input.length];
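For reference, a minimal round-trip sketch of the new writer paired with the existing reader might look like the following (the file path, phase value, and helper method name are illustrative, not part of this change):

// Write an Arguments object out as JSON, then read it back for a given phase.
void roundTripArguments() throws ZinggClientException {
    Arguments args = new Arguments();
    args.setModelId("100");
    Arguments.writeArgumentsToJSON("/tmp/zinggArgs.json", args);
    Arguments reloaded = Arguments.createArgumentsFromJSON("/tmp/zinggArgs.json", "test");
    // reloaded.getModelId() should now be "100"
}

The TestArguments change further down exercises the same round trip with full field definitions and pipes.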
5 changes: 4 additions & 1 deletion client/src/main/java/zingg/client/Client.java
@@ -253,5 +253,8 @@ public Dataset<Row> getMarkedRecords() {
return zingg.getMarkedRecords();
}


public Dataset<Row> getUnMarkedRecords() {
return zingg.getUnMarkedRecords();
}

}
41 changes: 36 additions & 5 deletions client/src/main/java/zingg/client/FieldDefinition.java
@@ -5,24 +5,25 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.sql.types.DataType;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
import com.fasterxml.jackson.databind.ser.std.StdSerializer;


/**
@@ -37,7 +38,10 @@ public class FieldDefinition implements

public static final Log LOG = LogFactory.getLog(FieldDefinition.class);

@JsonDeserialize(using = MatchTypeDeserializer.class) public List<MatchType> matchType;
@JsonDeserialize(using = MatchTypeDeserializer.class)
@JsonSerialize(using = MatchTypeSerializer.class)
public List<MatchType> matchType;

@JsonSerialize(using = DataTypeSerializer.class)
public DataType dataType;
public String fieldName;
@@ -163,6 +167,33 @@ public void serialize(DataType dType, JsonGenerator jsonGenerator,
}
}

public static class MatchTypeSerializer extends StdSerializer<List<MatchType>> {
public MatchTypeSerializer() {
this(null);
}

public MatchTypeSerializer(Class<List<MatchType>> t) {
super(t);
}

@Override
public void serialize(List<MatchType> matchType, JsonGenerator jsonGen, SerializerProvider provider)
throws IOException, JsonProcessingException {
try {
jsonGen.writeObject(getStringFromMatchType(matchType));
LOG.debug("Serializing custom type");
} catch (ZinggClientException e) {
throw new IOException(e);
}
}

public static String getStringFromMatchType(List<MatchType> matchType) throws ZinggClientException {
return String.join(",", matchType.stream()
.map(p -> p.value())
.collect(Collectors.toList()));
}
}

public static class MatchTypeDeserializer extends StdDeserializer<List<MatchType>> {
private static final long serialVersionUID = 1L;

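The practical effect is that a field's matchType list is written to the config as a single comma-joined string, which the paired MatchTypeDeserializer reads back on load. A minimal sketch using the new helper (the wrapper method and the chosen match types are illustrative):

// Flatten a match-type list the same way the serializer does before writing it.
String joinedMatchTypes() throws ZinggClientException {
    List<MatchType> types = Arrays.asList(MatchType.EXACT, MatchType.FUZZY);
    // yields "EXACT,FUZZY"; in the written JSON this appears roughly as "matchType" : "EXACT,FUZZY"
    return FieldDefinition.MatchTypeSerializer.getStringFromMatchType(types);
}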
2 changes: 2 additions & 0 deletions client/src/main/java/zingg/client/IZingg.java
@@ -22,6 +22,8 @@ public void init(Arguments args, String license)

public Dataset<Row> getMarkedRecords();

public Dataset<Row> getUnMarkedRecords();

public Long getMarkedRecordsStat(Dataset<Row> markedRecords, long value);

public Long getMatchedMarkedRecordsStat(Dataset<Row> markedRecords);
52 changes: 51 additions & 1 deletion client/src/test/java/zingg/client/TestArguments.java
@@ -7,6 +7,7 @@
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -15,6 +16,9 @@
import org.apache.commons.logging.LogFactory;
import org.junit.jupiter.api.Test;

import zingg.client.pipe.Format;
import zingg.client.pipe.Pipe;

public class TestArguments {

private static final String KEY_HEADER = "header";
@@ -241,4 +245,50 @@ public void testMatchTypeWrong() {


}
}

@Test
public void testWriteArgumentObjectToJSONFile() {
Arguments args = new Arguments();
try {
FieldDefinition fname = new FieldDefinition();
fname.setFieldName("fname");
fname.setDataType("\"string\"");
fname.setMatchType(Arrays.asList(MatchType.EXACT, MatchType.FUZZY, MatchType.PINCODE));
//fname.setMatchType(Arrays.asList(MatchType.EXACT));
fname.setFields("fname");
FieldDefinition lname = new FieldDefinition();
lname.setFieldName("lname");
lname.setDataType("\"string\"");
lname.setMatchType(Arrays.asList(MatchType.FUZZY));
lname.setFields("lname");
args.setFieldDefinition(Arrays.asList(fname, lname));

Pipe inputPipe = new Pipe();
inputPipe.setName("test");
inputPipe.setFormat(Format.CSV);
inputPipe.setProp("location", "examples/febrl/test.csv");
args.setData(new Pipe[] {inputPipe});

Pipe outputPipe = new Pipe();
outputPipe.setName("output");
outputPipe.setFormat(Format.CSV);
outputPipe.setProp("location", "examples/febrl/output.csv");
args.setOutput(new Pipe[] {outputPipe});

args.setBlockSize(400L);
args.setCollectMetrics(true);
args.setModelId("500");
Arguments.writeArgumentsToJSON("configFromArgObject.json", args);

//reload the same config file to check if deserialization is successful
Arguments newArgs = Arguments.createArgumentsFromJSON("configFromArgObject.json", "test");
assertEquals(newArgs.getModelId(), "500", "Model id is different");
assertEquals(newArgs.getBlockSize(), 400L, "Block size is different");
assertEquals(newArgs.getFieldDefinition().get(0).getFieldName(), "fname", "Field Definition[0]'s name is different");
String expectedMatchType = "[EXACT, FUZZY, PINCODE]";
assertEquals(newArgs.getFieldDefinition().get(0).getMatchType().toString(), expectedMatchType);
} catch (Exception | ZinggClientException e) {
e.printStackTrace();
}
}
}
26 changes: 26 additions & 0 deletions client/src/test/java/zingg/client/TestFieldDefinition.java
@@ -0,0 +1,26 @@
package zingg.client;

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.util.Arrays;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.jupiter.api.Test;

public class TestFieldDefinition {
public static final Log LOG = LogFactory.getLog(TestFieldDefinition.class);

@Test
public void testConvertAListOFMatchTypesIntoString() {
try {
List<MatchType> matchType = Arrays.asList(MatchType.EMAIL, MatchType.FUZZY, MatchType.NULL_OR_BLANK);
String expectedString = "EMAIL,FUZZY,NULL_OR_BLANK";
String strMatchType = FieldDefinition.MatchTypeSerializer.getStringFromMatchType(matchType);
assertEquals(expectedString, strMatchType);
} catch (Exception | ZinggClientException e) {
e.printStackTrace();
}
}
}
9 changes: 9 additions & 0 deletions core/src/main/java/zingg/ZinggBase.java
@@ -191,6 +191,15 @@ public Dataset<Row> getMarkedRecords() {
return null;
}

public Dataset<Row> getUnMarkedRecords() {
try {
return PipeUtil.read(spark, false, false, PipeUtil.getTrainingDataUnmarkedPipe(args));
} catch (ZinggClientException e) {
LOG.warn("No unmarked record");
}
return null;
}

public Long getMarkedRecordsStat(Dataset<Row> markedRecords, long value) {
return markedRecords.filter(markedRecords.col(ColName.MATCH_FLAG_COL).equalTo(value)).count() / 2;
}
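Together with the existing getMarkedRecords(), the new getUnMarkedRecords() lets a caller compare labelled and still-unlabelled training data, which is what the Python assessModel phase below does. A minimal Java sketch (the helper name is illustrative and zingg is assumed to be an already-initialized IZingg implementation):

// Compare labelled vs. pending training data; marked pairs are stored as two rows each,
// hence the division by 2, mirroring getMarkedRecordsStat above.
void reportTrainingData(IZingg zingg) {
    Dataset<Row> marked = zingg.getMarkedRecords();
    Dataset<Row> unMarked = zingg.getUnMarkedRecords();
    long labelledPairs = (marked == null) ? 0 : marked.count() / 2;
    long pendingPairs = (unMarked == null) ? 0 : unMarked.count() / 2;
    Long matchedPairs = (marked == null) ? 0L : zingg.getMatchedMarkedRecordsStat(marked);
}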
74 changes: 44 additions & 30 deletions python/phases/assessModel.py
@@ -1,37 +1,51 @@
from zingg import *
from pyspark.sql import DataFrame
import pandas as pd
import seaborn as sn
import matplotlib.pyplot as plt
import sys
import logging
from IPython.display import display

args = Arguments()
fname = FieldDefinition("fname","\"string\"",[sc._jvm.zingg.client.MatchType.FUZZY])
lname = FieldDefinition("lname","\"string\"",[sc._jvm.zingg.client.MatchType.FUZZY])
fieldDef = [fname, lname]
options = sc._jvm.zingg.client.ClientOptions(["--phase", "label", "--conf", "dummy", "--license", "dummy", "--email", "[email protected]"])
inputPipe = Pipe("test", "csv")
inputPipe.addProperty("location", "examples/febrl/test.csv")
args.setData(inputPipe)
args.setModelId("100")
args.setZinggDir("models")
args.setNumPartitions(4)
args.setLabelDataSampleSize(0.5)
args.setFieldDefinition(fieldDef)
print(args.getArgs)
#Zingg execution for the given phase
client = Client(args, options)
client.init()
client.execute()
jMarkedDF = client.getMarkedRecords()
print(jMarkedDF)
markedDF = DataFrame(jMarkedDF, sqlContext)
print(markedDF)
pMarkedDF = markedDF.toPandas()
display(pMarkedDF)

#marked = client.getMarkedRecordsStat(mark, value)
#matched_marked = client.getMatchedMarkedRecordsStat(mark)
#unmatched_marked = client.getUnmatchedMarkedRecordsStat(mark)
#unsure_marked = client.getUnsureMarkedRecordsStat(mark)
logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger("zingg.assessModel")

def main():
    LOG.info("Phase AssessModel starts")

    # exclude argv[0], which is just the path of the current script
    options = ClientOptions(sys.argv[1:])
    options.setPhase("label")
    arguments = Arguments.createArgumentsFromJSON(options.getConf(), options.getPhase())
    client = Zingg(arguments, options)
    client.init()

    pMarkedDF = client.getPandasDfFromDs(client.getMarkedRecords())
    pUnMarkedDF = client.getPandasDfFromDs(client.getUnMarkedRecords())

    total_marked = pMarkedDF.shape[0]
    total_unmarked = pUnMarkedDF.shape[0]
    matched_marked = client.getMatchedMarkedRecordsStat()
    unmatched_marked = client.getUnmatchedMarkedRecordsStat()
    unsure_marked = client.getUnsureMarkedRecordsStat()

    LOG.info("")
    LOG.info("No. of Records Marked : %d", total_marked)
    LOG.info("No. of Records UnMarked : %d", total_unmarked)
    LOG.info("No. of Matches : %d", matched_marked)
    LOG.info("No. of Non-Matches : %d", unmatched_marked)
    LOG.info("No. of Not Sure : %d", unsure_marked)
    LOG.info("")
    plotConfusionMatrix(pMarkedDF)

    LOG.info("Phase AssessModel ends")

def plotConfusionMatrix(pMarkedDF):
    # As no model has been built yet and Zingg is still learning, drop the records with prediction = -1
    pMarkedDF.drop(pMarkedDF[pMarkedDF[ColName.PREDICTION_COL] == -1].index, inplace=True)

    confusion_matrix = pd.crosstab(pMarkedDF[ColName.MATCH_FLAG_COL], pMarkedDF[ColName.PREDICTION_COL], rownames=['Actual'], colnames=['Predicted'])
    # each labelled pair is stored as two rows, so halve the counts to report pairs
    confusion_matrix = confusion_matrix / 2
    sn.heatmap(confusion_matrix, annot=True)
    plt.show()

if __name__ == "__main__":
    main()