Skip to content

Commit

Permalink
* Fix oov probability
Browse files: browse the repository at this point in the history
  • Loading branch information
honnibal committed Feb 6, 2016
1 parent af8514c commit a95974a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 23 deletions.
17 changes: 9 additions & 8 deletions spacy/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@ def prefix(string):
def suffix(string):
return string[-3:]

@staticmethod
def prob(string):
return -30

@staticmethod
def cluster(string):
return 0
Expand Down Expand Up @@ -119,16 +115,16 @@ def is_stop(string):
return 0

@classmethod
def default_lex_attrs(cls):
def default_lex_attrs(cls, *args, **kwargs):
oov_prob = kwargs.get('oov_prob', -20)
return {
attrs.LOWER: cls.lower,
attrs.NORM: cls.norm,
attrs.SHAPE: cls.shape,
attrs.PREFIX: cls.prefix,
attrs.SUFFIX: cls.suffix,
attrs.CLUSTER: cls.cluster,
attrs.PROB: lambda string: -10.0,

attrs.PROB: lambda string: oov_prob,
attrs.IS_ALPHA: cls.is_alpha,
attrs.IS_ASCII: cls.is_ascii,
attrs.IS_DIGIT: cls.is_digit,
Expand Down Expand Up @@ -159,7 +155,12 @@ def default_ner_labels(cls):
@classmethod
def default_vocab(cls, package, get_lex_attr=None):
if get_lex_attr is None:
get_lex_attr = cls.default_lex_attrs()
if package.has_file('vocab', 'oov_prob'):
with package.open(('vocab', 'oov_prob')) as file_:
oov_prob = float(file_.read().strip())
get_lex_attr = cls.default_lex_attrs(oov_prob=oov_prob)
else:
get_lex_attr = cls.default_lex_attrs()
if hasattr(package, 'dir_path'):
return Vocab.from_package(package, get_lex_attr=get_lex_attr)
else:
Expand Down
29 changes: 14 additions & 15 deletions spacy/vocab.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -110,21 +110,20 @@ cdef class Vocab:
# TODO: This is hopelessly broken. The pickled state holds only paths to files
# inside a temp directory, which we then never clean up. This method therefore
# only pretends to work. What we need to do is write the state to an archive file.
raise NotImplementedError
#tmp_dir = tempfile.mkdtemp()
#lex_loc = path.join(tmp_dir, 'lexemes.bin')
#str_loc = path.join(tmp_dir, 'strings.json')
#vec_loc = path.join(tmp_dir, 'vec.bin')

#self.dump(lex_loc)
#with io.open(str_loc, 'w', encoding='utf8') as file_:
# self.strings.dump(file_)

#self.dump_vectors(vec_loc)
#
#state = (str_loc, lex_loc, vec_loc, self.morphology, self.get_lex_attr,
# self.serializer_freqs, self.data_dir)
#return (unpickle_vocab, state, None, None)
tmp_dir = tempfile.mkdtemp()
lex_loc = path.join(tmp_dir, 'lexemes.bin')
str_loc = path.join(tmp_dir, 'strings.json')
vec_loc = path.join(tmp_dir, 'vec.bin')

self.dump(lex_loc)
with io.open(str_loc, 'w', encoding='utf8') as file_:
self.strings.dump(file_)

self.dump_vectors(vec_loc)

state = (str_loc, lex_loc, vec_loc, self.morphology, self.get_lex_attr,
self.serializer_freqs, self.data_dir)
return (unpickle_vocab, state, None, None)

cdef const LexemeC* get(self, Pool mem, unicode string) except NULL:
'''Get a pointer to a LexemeC from the lexicon, creating a new Lexeme
Expand Down

0 comments on commit a95974a

Please sign in to comment.