From 98582c4244e7413bd387624d316cf971d216a777 Mon Sep 17 00:00:00 2001 From: Mitchell Stern Date: Tue, 29 Nov 2016 09:18:25 -0800 Subject: [PATCH 1/2] Fixed slicing in Python interface --- python/_dynet.pyx | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/python/_dynet.pyx b/python/_dynet.pyx index a1a0ecbc4..175f6634a 100644 --- a/python/_dynet.pyx +++ b/python/_dynet.pyx @@ -622,11 +622,16 @@ cdef class Expression: #{{{ return "expression %s/%s" % (self.vindex, self.cg_version) # __getitem__ and __getslice__ in one for python 3 compatibility - def __getitem__(self, object index): - if isinstance(index, int): - return pick(self, index) - - return pickrange(self, index[0], index[1]) + def __getitem__(self, index): + assert isinstance(index, (int, slice)) + if isinstance(index, int): + return pick(self, index) + else: + if index.start is None or index.stop is None: + raise ValueError("Default start and stop indices not yet supported.") + if index.step is not None: + raise ValueError("Step sizes not yet supported.") + return pickrange(self, index.start, index.stop) cpdef scalar_value(self, recalculate=False): if self.cg_version != _cg._cg_version: raise RuntimeError("Stale Expression (created before renewing the Computation Graph).") From 45bf81faa4564ed09ca7817e504984b93e5ab4e6 Mon Sep 17 00:00:00 2001 From: Mitchell Stern Date: Tue, 6 Dec 2016 03:39:33 -0800 Subject: [PATCH 2/2] Added support for negative indices and default slice arguments --- python/_dynet.pyx | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/python/_dynet.pyx b/python/_dynet.pyx index 6ffcc5093..270f712c3 100644 --- a/python/_dynet.pyx +++ b/python/_dynet.pyx @@ -624,14 +624,41 @@ cdef class Expression: #{{{ # __getitem__ and __getslice__ in one for python 3 compatibility def __getitem__(self, index): assert isinstance(index, (int, slice)) + cdef int rows = self.c().dim().rows() + cdef int i, j if isinstance(index, int): - return pick(self, index) + i = index + if i > rows - 1: + raise IndexError("Index too large: %d > %d" % (i, rows - 1)) + if i < -rows: + raise IndexError("Index too small: %d < %d" % (i, -rows)) + if i < 0: + i += rows + return pick(self, i) else: - if index.start is None or index.stop is None: - raise ValueError("Default start and stop indices not yet supported.") + i = 0 + j = rows + if index.start is not None: + i = index.start + if i > rows - 1: + raise IndexError("Start index too large: %d > %d" % (i, rows - 1)) + if i < -rows: + raise IndexError("Start index too small: %d < %d" % (i, -rows)) + if i < 0: + i += rows + if index.stop is not None: + j = index.stop + if j > rows - 1: + raise IndexError("Stop index too large: %d > %d" % (j, rows - 1)) + if j < -rows: + raise IndexError("Stop index too small: %d < %d" % (j, -rows)) + if j < 0: + j += rows + if i >= j: + raise ValueError("Improper slice: start index must come strictly before stop index") if index.step is not None: raise ValueError("Step sizes not yet supported.") - return pickrange(self, index.start, index.stop) + return pickrange(self, i, j) cpdef scalar_value(self, recalculate=False): if self.cg_version != _cg._cg_version: raise RuntimeError("Stale Expression (created before renewing the Computation Graph).")