diff --git a/.directory b/.directory new file mode 100644 index 0000000..8c80f0e --- /dev/null +++ b/.directory @@ -0,0 +1,3 @@ +[Dolphin] +Timestamp=2017,7,29,0,16,50 +Version=3 diff --git a/.gitignore b/.gitignore index c889832..8dd7562 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,14 @@ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# Bin and txt files +*.bin +*.txt +*.npy + + + # User-specific stuff: .idea/**/workspace.xml .idea/**/tasks.xml diff --git a/Dockerfile b/Dockerfile index 2833c7d..7657dab 100644 --- a/Dockerfile +++ b/Dockerfile @@ -29,7 +29,6 @@ ADD . $HOME_DIR/word2vec RUN apt-get -y update && \ apt-get -y install \ - python3-pip \ language-pack-en \ vim \ libopenblas-dev @@ -60,12 +59,6 @@ ENV LANG en_US.UTF-8 ENV LANGUAGE en_US:en ENV LC_ALL en_US.UTF-8 -############################################################ -# Exposing ports -############################################################ - -EXPOSE 9300 9200 2181 9092 - ############################################################ # Running the uberjar ############################################################ diff --git a/conf/parser-conf.properties b/conf/parser-conf.properties deleted file mode 100644 index 96d2151..0000000 --- a/conf/parser-conf.properties +++ /dev/null @@ -1,9 +0,0 @@ -# Settings for parsing the input text - -input.data.file=data/deploy-sample.txt -column.delimiter=\\*\\*\\* - -selected.cloumns.indices= 4 -columns.size=16 - - diff --git a/conf/word2vec-default.properties b/conf/word2vec-default.properties index 9f747a3..b2db048 100644 --- a/conf/word2vec-default.properties +++ b/conf/word2vec-default.properties @@ -1,11 +1,20 @@ - - +# path to the text corpus (text corpus should be a txt file where each line represents a document). input.corpus.path=corpus/sample-data.txt -output.model.save.path=model/model-alpha.bin +# path (name) to the word2vec model to be saved. +output.model.save.path=model/model-v-0.1.0-alpha.bin +# min frequency of words to be used in training (words with less frequency than this will be dropped off the vocabulary). min.word.frequency=2 -number.of.iterations=100 -layer.size=300 + +# number of training epochs. +number.of.iterations=10 + +# dimension of the word vectors. +layer.size=250 + +# window size for choosing the context of a word. window.size=5 -learning.rate=0.01 \ No newline at end of file + +# learning rate for the algorithm. +learning.rate=0.015 diff --git a/docker-compose.yml b/docker-compose.yml index f1b8108..c50e7b8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,10 +1,11 @@ version: '2' services: word2vec_trainer: - build: ./ + #image: registry.gitlab.com/hosseinabedi/meliora:development + build: . container_name: word2wec volumes: - - ./conf:/home/badger/conf - - ./log:/home/badger/log - - .corpus:/home/badger/corpus + - ./conf:/home/word2vec/conf + - ./log:/home/word2vec/log + - ./corpus:/home/word2vec/corpus diff --git a/meliora.iml b/meliora.iml deleted file mode 100644 index ba733e8..0000000 --- a/meliora.iml +++ /dev/null @@ -1,322 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/model/.keep b/model/.keep new file mode 100644 index 0000000..e69de29 diff --git a/model/model-alpha.bin b/model/model-alpha.bin deleted file mode 100644 index 46aeb51..0000000 Binary files a/model/model-alpha.bin and /dev/null differ diff --git a/src/main/java/Cleaner.java b/src/main/java/Cleaner.java new file mode 100644 index 0000000..3d3c0c0 --- /dev/null +++ b/src/main/java/Cleaner.java @@ -0,0 +1,151 @@ +/* + * To change this license header, choose License Headers in Project Properties. + * To change this template file, choose Tools | Templates + * and open the template in the editor. + */ + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.UnsupportedEncodingException; +import java.util.ArrayList; +import java.util.List; +import java.util.StringTokenizer; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * + * @author behnam + */ +public class Cleaner { + + private final String punctuationPath = "stoplists/Cleaner/Punctuations.txt"; + private final String conjPath = "stoplists/Persian/CONJ.txt"; + private final String detPath = "stoplists/Persian/DET.txt"; + private final String pPath = "stoplists/Persian/P.txt"; + private final String postpPath = "stoplists/Persian/POSTP.txt"; + private final String proPath = "stoplists/Persian/PRO.txt"; + private final String stopwordPath = "stoplists/Persian/persian.txt"; + + private List punctuations; + private List conj; + private List det; + private List p; + private List postp; + private List pro; + private List stopword; + + public static final Pattern RTL_CHARACTERS = Pattern.compile("[\u0600-\u06FF\u0750-\u077F\u0590-\u05FF\uFE70-\uFEFF]"); + + public Cleaner() throws FileNotFoundException, UnsupportedEncodingException, IOException { + punctuations = initialize(punctuationPath); +// conj = initialize(conjPath); +// det = initialize(detPath); +// p = initialize(pPath); +// postp = initialize(postpPath); +// pro = initialize(proPath); + stopword = initialize(stopwordPath); + } + + private List initialize(String path) throws FileNotFoundException, UnsupportedEncodingException, IOException{ + List result = new ArrayList<>(); + File file = new File(path); + BufferedReader in = new BufferedReader(new InputStreamReader(new FileInputStream(file), "UTF8")); + String line = in.readLine(); + while(line != null){ + result.add(line); + line = in.readLine(); + } + return result; + } + + private String removeUnwantedTokens(String text, List list){ + StringTokenizer tokenizer = new StringTokenizer(text); + String result = ""; + while(tokenizer.hasMoreTokens()){ + String token = tokenizer.nextToken(); + if(!contains(list, token)){ + result += " " + token; + } + } + return result.trim(); + } + + private boolean contains(List list, String word){ + for(int i=0;i intersectCollections(List> topicList){ + + Collection topics = new ArrayList<>(); + Collection all = new ArrayList<>(); + for (Collection coll:topicList){ + for(String word:coll){ + + + } + } + + Set allTopicSet = new HashSet(all); + + System.out.println(allTopicSet); + + return topics; + } + + Collection unionCollections(List> topicList){ + + Collection topics = new ArrayList<>(); + + Collection all = new ArrayList<>(); + for (Collection coll:topicList){ + all.addAll(coll); + } + + + + Set allTopicSet = new HashSet(all); + + System.out.println(allTopicSet); + + return topics; + } + + + private Collection getTopN(String word, Word2Vec model, int n){ + return model.wordsNearest(word, n); + } + + + private String [] tokenize(String text) throws IOException { + Cleaner cleaner = new Cleaner(); + return cleaner.splitAtSpaces(cleaner.clean(text)); + } + + private INDArray cosMul(String document, Word2Vec model, int vectorSize) throws IOException { + + String[] tokens = tokenize(document); + INDArray baseVector = Nd4j.ones(vectorSize); + + for (String token : tokens) { + if (model.hasWord(token)) { + + baseVector = baseVector.addRowVector(Nd4j.create(model.getWordVector(token))); + + } + } + return baseVector; + } + + + + private void munchText () throws IOException { + + Cleaner cleaner = new Cleaner(); + + + while (true) { + System.out.print("Do you want to play (Y/N) ?\n"); + Scanner ans = new Scanner(System.in); + String answer = ans.nextLine(); + List> topicList = new ArrayList>(); + Collection similarWords = new ArrayList<>(); + + if (answer.equalsIgnoreCase("Y")) { + Scanner input1 = new Scanner(System.in); + Scanner input2 = new Scanner(System.in); + + System.out.println("Give me the document. \n"); + String document = input1.nextLine().trim(); + + System.out.println("Give me the tags. \n"); + String tags = input2.nextLine().trim(); + String [] tagList = cleaner.splitAtSpaces(tags); + + INDArray aggregation = cosMul(document, model, 250); + + + + //System.out.println(aggregation); + + for (String tag:tagList){ + if (model.hasWord(tag)) { + INDArray vector = Nd4j.create(model.getWordVector(tag)); + + double similarity = Transforms.cosineSim(vector, aggregation); + + System.out.println(tag + ": " + similarity); + + } + } + + + } else if (answer.equalsIgnoreCase("N")) { + System.out.print("Thank you, all done!"); + break; + } else { + System.out.print("Try again with (Y/N) only !"); + } + + } + } + + + public static void main(String[] args) throws Exception { + + Explorer explorer = new Explorer(); + explorer.munchText(); + } +} diff --git a/src/main/java/Trainer.java b/src/main/java/Trainer.java index 7f56ed6..9224adf 100644 --- a/src/main/java/Trainer.java +++ b/src/main/java/Trainer.java @@ -1,7 +1,8 @@ import java.io.*; import java.util.Properties; - import org.apache.log4j.Logger; +import org.apache.parquet.format.FileMetaData; +import org.datavec.api.util.ClassPathResource; import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; import org.deeplearning4j.models.word2vec.Word2Vec; import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; @@ -39,8 +40,7 @@ private Trainer(){ InputStream inputStream = new FileInputStream(configFile); Properties props = new Properties(); props.load(inputStream); - - this.inputCorpusPath = props.getProperty("input.corpus.path"); + this.inputCorpusPath = new File(props.getProperty("input.corpus.path").trim()).getAbsolutePath(); this.modelSavePath = props.getProperty("output.model.save.path"); this.minWordFrequency = Integer.parseInt(props.getProperty("min.word.frequency").trim()); this.iterations = Integer.parseInt(props.getProperty("number.of.iterations").trim()); @@ -63,6 +63,7 @@ private void saveModel(String filePath, Word2Vec model){ WordVectorSerializer.writeWord2VecModel(model, filePath); + } private Word2Vec loadModel(String filePath){ diff --git a/start.sh b/start.sh index c932d46..f049fcc 100644 --- a/start.sh +++ b/start.sh @@ -2,5 +2,5 @@ echo "Kicking off!" -java -cp target/meliora-1.0-alpha-jar-with-dependencies.jar Trainer +java -Xms5g -Xmx63g -cp target/meliora-1.0-alpha-jar-with-dependencies.jar Trainer diff --git a/stoplists/README b/stoplists/README new file mode 100644 index 0000000..9757a34 --- /dev/null +++ b/stoplists/README @@ -0,0 +1,3 @@ +English stoplist is the standard Mallet stoplist. + +German, French, Finnish are borrowed from http://www.ranks.nl.