Skip to content

Commit

Permalink
add python3 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
hayounav committed Oct 27, 2016
1 parent 068f154 commit 4efe677
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 27 deletions.
6 changes: 3 additions & 3 deletions python/dynet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
elif '--dynet-gpu' in sys.argv: # the python gpu switch.
sys.argv.remove('--dynet-gpu')
def print_graphviz(**kwarge):
print "Run with --dynet-viz to get the visualization behavior."
print("Run with --dynet-viz to get the visualization behavior.")
from _gdynet import *
elif '--dynet-gpus' in sys.argv or '--dynet-gpu-ids' in sys.argv: # but using the c++ gpu switches suffices to trigger gpu.
def print_graphviz(**kwarge):
print "Run with --dynet-viz to get the visualization behavior."
print("Run with --dynet-viz to get the visualization behavior.")
from _gdynet import *
else:
def print_graphviz(**kwarge):
print "Run with --dynet-viz to get the visualization behavior."
print("Run with --dynet-viz to get the visualization behavior.")
from _dynet import *

__version__ = 2.0
37 changes: 24 additions & 13 deletions python/dynet.pyx
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
# on numpy arrays, see: https://github.com/cython/cython/wiki/tutorials-NumpyPointerToC

from __future__ import print_function
import sys
from cython.operator cimport dereference as deref
from libc.stdlib cimport malloc, free
import numpy as np
import cPickle as pickle

# python3 pickle already uses the c implementaion
try:
import cPickle as pickle
except ImportError:
import pickle

import os.path
# TODO:
# - set random seed (in DYNET)
Expand Down Expand Up @@ -36,7 +42,7 @@ cimport dynet
cdef init(random_seed=None):
cdef int argc = len(sys.argv)
cdef char** c_argv
args = [bytes(x) for x in sys.argv]
args = [bytes(x, encoding="utf-8") for x in sys.argv]
c_argv = <char**>malloc(sizeof(char*) * len(args)) # TODO check failure?
for idx, s in enumerate(args):
c_argv[idx] = s
Expand Down Expand Up @@ -316,8 +322,8 @@ cdef class Model: # {{{
if not components:
self.save_all(fname)
return
fh = file(fname+".pym","w")
pfh = file(fname+".pyk","w")
fh = open(fname+".pym","w")
pfh = open(fname+".pyk","w")
cdef CModelSaver *saver = new CModelSaver(fname, self.thisptr)
for c in components:
self._save_one(c,saver,fh,pfh)
Expand Down Expand Up @@ -362,18 +368,18 @@ cdef class Model: # {{{
saveable.restore_components(items)
return saveable
else:
print "Huh?"
print("Huh?")
assert False,"unsupported type " + tp

cpdef load(self, string fname):
if not os.path.isfile(fname+".pym"):
self.load_all(fname)
return
with file(fname+".pym","r") as fh:
with open(fname+".pym","r") as fh:
types = fh.read().strip().split()

cdef CModelLoader *loader = new CModelLoader(fname, self.thisptr)
with file(fname+".pyk","r") as pfh:
with open(fname+".pyk","r") as pfh:
params = []
itypes = iter(types)
while True: # until iterator is done
Expand Down Expand Up @@ -589,11 +595,16 @@ cdef class Expression: #{{{
def __str__(self):
return "exprssion %s/%s" % (<int>self.vindex, self.cg_version)

def __getitem__(self, int i):
return pick(self, i)
# def __getitem__(self, int i):
# return pick(self, i)
def __getitem__(self, object index):
if isinstance(index, int):
return pick(self, index)

return pickrange(self, index[0], index[1])

def __getslice__(self, int i, int j):
return pickrange(self, i, j)
# def __getslice__(self, int i, int j):
# return pickrange(self, i, j)

cpdef scalar_value(self, recalculate=False):
if self.cg_version != _cg._cg_version: raise RuntimeError("Stale Expression (created before renewing the Computation Graph).")
Expand Down Expand Up @@ -885,7 +896,7 @@ cpdef Expression esum(list xs):
for x in xs:
ensure_freshness(x)
cvec.push_back(x.c())
#print >> sys.stderr, cvec.size()
#print(cvec.size(), file=sys.stderr)
return Expression.from_cexpr(x.cg_version, c_sum(cvec))

cpdef Expression average(list xs):
Expand Down
16 changes: 8 additions & 8 deletions python/dynet_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,9 +969,9 @@ def print_graphviz(compact=False, show_dims=True, expression_names=None, lookup_
(nodes, birnn_collapse_to) = collapse_birnn_states(nodes, compact)
collapse_to.update(birnn_collapse_to)

print 'digraph G {'
print ' rankdir=BT;'
if not compact: print ' nodesep=.05;'
print('digraph G {')
print(' rankdir=BT;')
if not compact: print(' nodesep=.05;')

node_types = defaultdict(set)
for n in nodes:
Expand All @@ -982,7 +982,7 @@ def print_graphviz(compact=False, show_dims=True, expression_names=None, lookup_
'2_regular': '[shape=rect]',
'3_rnn_state': '[shape=rect, peripheries=2]',
}[node_type]
print ' node %s; ' % (style), ' '.join(node_types[node_type])
print(' node %s; ' % (style), ' '.join(node_types[node_type]))

# all_nodes = set(line.strip().split()[0] for line in node_def_lines)
for n in nodes:
Expand All @@ -995,9 +995,9 @@ def print_graphviz(compact=False, show_dims=True, expression_names=None, lookup_
label = '%s\\n%s' % (label, shape_str(n.input_dim))
if n.output_dim.invalid() or (n.input_dim is not None and n.input_dim.invalid()):
n.features += " [color=red,style=filled,fillcolor=red]"
print ' %s [label="%s"] %s;' % (n.name, label, n.features)
print(' %s [label="%s"] %s;' % (n.name, label, n.features))
for c in n.children:
print ' %s -> %s;' % (c, n.name)
print(' %s -> %s;' % (c, n.name))

rnn_states = [] # (name, rnn_name, state_idx)
rnn_state_re = re.compile("[^-]+-(.)-(\\d+)")
Expand All @@ -1016,6 +1016,6 @@ def print_graphviz(compact=False, show_dims=True, expression_names=None, lookup_
group_name_n = collapse_to.get(name_n, name_n)
edges.add((group_name_p, group_name_n))
for (name_p, name_n) in edges:
print ' %s -> %s [style=dotted];' % (name_p, name_n) # ,dir=both
print(' %s -> %s [style=dotted];' % (name_p, name_n)) # ,dir=both

print '}'
print('}')
7 changes: 4 additions & 3 deletions python/setup.py.in
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from setuptools import setup
from setuptools.extension import Extension
#from setuptools.extension import Extension
from Cython.Distutils.extension import Extension
from Cython.Distutils import build_ext


Expand All @@ -18,7 +19,7 @@ if platform.system() == 'Darwin':


ext_cpu = Extension(
"_dynet", # name of extension
"dynet", # name of extension
["dynet.pyx"], # filename of our Pyrex/Cython source
language="c++", # this causes Pyrex/Cython to create C++ source
include_dirs=["${PROJECT_SOURCE_DIR}", # this is the location of the main dynet directory.
Expand All @@ -32,7 +33,7 @@ ext_cpu = Extension(
)

ext_gpu = Extension(
"_gdynet", # name of extension
"gdynet", # name of extension
["gdynet.pyx"], # filename of our Pyrex/Cython source
language="c++", # this causes Pyrex/Cython to create C++ source
include_dirs=["${PROJECT_SOURCE_DIR}", # this is the location of the main dynet directory.
Expand Down

0 comments on commit 4efe677

Please sign in to comment.