-
Notifications
You must be signed in to change notification settings - Fork 10
/
cyk_table.py
135 lines (112 loc) · 4.4 KB
/
cyk_table.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
Representation of a CYK table
"""
from pyformlang.cfg.parse_tree import ParseTree
class CYKTable:
    """
    A CYK table

    Parameters
    ----------
    cfg : A context-free grammar
    word : tuple of Terminals
        The word from which we construct the CYK table
    """

    def __init__(self, cfg, word):
        # The table only has cells for non-empty spans, so the empty word is
        # decided on the original grammar (see generate_word).
        self._cfg = cfg
        self._cnf = cfg.to_normal_form()
        self._word = word
        self._productions_d = {}
        self._set_productions_by_body()
        self._cyk_table = {}
        if not self._generates_all_terminals():
            # Some terminal of the word is produced by no variable: nothing
            # can be derived, so only the full-word cell is needed.
            self._cyk_table[(0, len(self._word))] = set()
        else:
            self._set_cyk_table()

    def _set_productions_by_body(self):
        # Index the normal-form productions by body, so that a cell can be
        # filled by looking up ``(terminal,)`` or ``(var_b, var_c)``.
        for production in self._cnf.productions:
            self._productions_d.setdefault(
                tuple(production.body), []).append(production.head)

    def _set_cyk_table(self):
        self._initialize_cyk_table()
        self._propagate_in_cyk_table()

    def _get_windows(self):
        # Windows must be produced in increasing order of length, so that the
        # sub-windows a window is built from are already complete.
        for window_size in range(2, len(self._word) + 1):
            for start_window in range(len(self._word) - window_size + 1):
                yield start_window, start_window + window_size

    def _get_all_window_pairs(self, start_window, end_window):
        # Every (left node, right node) pair over every split point of the
        # window [start_window, end_window).
        for mid_window in range(start_window + 1, end_window):
            for var_b in self._cyk_table[(start_window, mid_window)]:
                for var_c in self._cyk_table[(mid_window, end_window)]:
                    yield var_b, var_c

    def _propagate_in_cyk_table(self):
        # For every window, add a node for each head A of a production
        # A -> B C where B derives the left part and C the right part.
        for start_window, end_window in self._get_windows():
            for var_b, var_c in self._get_all_window_pairs(start_window,
                                                           end_window):
                for var_a in self._productions_d.get((var_b.value,
                                                      var_c.value), []):
                    self._cyk_table[(start_window, end_window)].add(
                        CYKNode(var_a, var_b, var_c))

    def _initialize_cyk_table(self):
        # Length-1 windows: one node per head producing the terminal directly.
        for i, terminal in enumerate(self._word):
            self._cyk_table[(i, i + 1)] = \
                {CYKNode(x, CYKNode(terminal))
                 for x in self._productions_d[(terminal,)]}
        for window_size in range(2, len(self._word) + 1):
            for start_window in range(len(self._word) - window_size + 1):
                # Sets avoid duplicate nodes (nodes compare by value), which
                # would otherwise make the propagation loops longer.
                self._cyk_table[
                    (start_window, start_window + window_size)] = set()

    def generate_word(self):
        """
        Checks if the word is generated

        Returns
        -------
        is_generated : bool
        """
        if not self._word:
            # There is no cell for the empty span; ask the original grammar.
            return self._cfg.generate_epsilon()
        return self._cnf.start_symbol in self._cyk_table[(0, len(self._word))]

    def _generates_all_terminals(self):
        # True when every terminal of the word is the body of a production,
        # i.e. the first row of the table can be filled.
        return all((terminal,) in self._productions_d
                   for terminal in self._word)

    def get_parse_tree(self):
        """
        Give the parse tree associated with this CYK Table

        Returns
        -------
        parse_tree : :class:`~pyformlang.cfg.ParseTree`

        Raises
        ------
        DerivationDoesNotExist
            If the word is not generated by the grammar (including the empty
            word when the grammar does not generate epsilon).
        """
        if not self.generate_word():
            raise DerivationDoesNotExist
        if not self._word:
            return CYKNode(self._cnf.start_symbol)
        full_window = self._cyk_table[(0, len(self._word))]
        # CYKNode compares equal to its value, so this selects the node whose
        # value is the start symbol (at most one, since the cell is a set).
        return next(node for node in full_window
                    if node == self._cnf.start_symbol)
class CYKNode(ParseTree):
    """
    A node in the CYK table, also usable as a parse-tree node.

    Parameters
    ----------
    value :
        The symbol stored in this node.
    left_son : CYKNode, optional
        Left child, registered in ``sons`` when given.
    right_son : CYKNode, optional
        Right child, registered in ``sons`` after the left one when given.

    Equality and hashing only look at ``value``. Comparing with an object
    that is not a CYKNode compares ``value`` with that object directly.
    """

    def __init__(self, value, left_son=None, right_son=None):
        super().__init__(value)
        self.value = value
        self.left_son = left_son
        self.right_son = right_son
        # Keep only the children that exist, left before right.
        for child in (left_son, right_son):
            if child is not None:
                self.sons.append(child)

    def __eq__(self, other):
        # A non-node is compared against our value, so a bare symbol can be
        # looked up in a set of nodes.
        target = other.value if isinstance(other, CYKNode) else other
        return self.value == target

    def __hash__(self):
        return hash(self.value)
class DerivationDoesNotExist(Exception):
    """
    Exception raised when the word cannot be derived.

    Raised by ``CYKTable.get_parse_tree`` when no parse tree exists for the
    word.
    """