Skip to content

Commit

Permalink
Support for passing np.arrays as cython args
Browse files Browse the repository at this point in the history
  • Loading branch information
heeres committed May 29, 2014
1 parent d8bfe49 commit 923bf99
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 30 deletions.
35 changes: 17 additions & 18 deletions qutip/cy/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,20 @@ def dedent(self):
raise SyntaxError("Error in code generator")
self.level -= 1

def _get_arg_str(self, args):
if len(args) == 0:
return ''

ret = ''
for name, value in self.args.iteritems():
if isinstance(value, np.ndarray):
ret += ", np.ndarray[np.%s_t, ndim=%d] %s" % \
(value.dtype.name, len(value.shape), name)
else:
kind = type(value).__name__
ret += ", np." + kind + "_t " + name
return ret

def ODE_func_header(self):
"""Creates function header for time-dependent ODE RHS."""
func_name = "def cy_td_ode_rhs("
Expand All @@ -142,12 +156,7 @@ def ODE_func_header(self):
input_vars += (", np.ndarray[CTYPE_t, ndim=1] data" + str(k) +
", np.ndarray[int, ndim=1] idx" + str(k) +
", np.ndarray[int, ndim=1] ptr" + str(k))
if self.args:
td_consts = list(self.args.items())
td_len = len(td_consts)
for jj in range(td_len):
kind = type(td_consts[jj][1]).__name__
input_vars += ", np." + kind + "_t " + td_consts[jj][0]
input_vars += self._get_arg_str(self.args)
func_end = "):"
return [func_name + input_vars + func_end]

Expand All @@ -160,12 +169,7 @@ def col_spmv_header(self):
input_vars = ("int which, double t, np.ndarray[CTYPE_t, ndim=1] " +
"data, np.ndarray[int] idx,np.ndarray[int] " +
"ptr,np.ndarray[CTYPE_t, ndim=1] vec")
if len(self.args) > 0:
td_consts = list(self.args.items())
td_len = len(td_consts)
for jj in range(td_len):
kind = type(td_consts[jj][1]).__name__
input_vars += ", np." + kind + " " + td_consts[jj][0]
input_vars += self._get_arg_str(self.args)
func_end = "):"
return [func_name + input_vars + func_end]

Expand All @@ -178,12 +182,7 @@ def col_expect_header(self):
input_vars = ("int which, double t, np.ndarray[CTYPE_t, ndim=1] " +
"data, np.ndarray[int] idx,np.ndarray[int] " +
"ptr,np.ndarray[CTYPE_t, ndim=1] vec")
if len(self.args) > 0:
td_consts = list(self.args.items())
td_len = len(td_consts)
for jj in range(td_len):
kind = type(td_consts[jj][1]).__name__
input_vars += ", np." + kind + "_t" + " " + td_consts[jj][0]
input_vars += self._get_arg_str(self.args)
func_end = "):"
return [func_name + input_vars + func_end]

Expand Down
17 changes: 11 additions & 6 deletions qutip/mesolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,11 @@ def _mesolve_list_str_td(H_list, rho0, tlist, c_list, e_ops, args, opt,
for k in range(n_L_terms):
string_list.append("Ldata[%d], Linds[%d], Lptrs[%d]" % (k, k, k))
for name, value in args.items():
string_list.append(str(value))
if isinstance(value, np.ndarray):
globals()['var_%s'%name] = value
string_list.append('var_%s'%name)
else:
string_list.append(str(value))
parameter_string = ",".join(string_list)

#
Expand Down Expand Up @@ -721,11 +725,12 @@ def _mesolve_list_td(H_func, rho0, tlist, c_op_list, e_ops, args, opt,
string += ("Ldata[%d], Linds[%d], Lptrs[%d]," % (k, k, k))

if args:
td_consts = args.items()
for elem in td_consts:
string += str(elem[1])
if elem != td_consts[-1]:
string += (",")
for name, value in args.items():
if isinstance(value, np.ndarray):
globals()['var_%s'%name] = value
string += 'var_%s,'%name
else:
string += str(value) + ','

# run code generator
if not opt.rhs_reuse or odeconfig.tdfunc is None:
Expand Down
17 changes: 11 additions & 6 deletions qutip/sesolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,11 @@ def _sesolve_list_str_td(H_list, psi0, tlist, e_ops, args, opt,
for k in range(n_L_terms):
string_list.append("Ldata[%d], Linds[%d], Lptrs[%d]" % (k, k, k))
for name, value in args.items():
string_list.append(str(value))
if isinstance(value, np.ndarray):
globals()['var_%s'%name] = value
string_list.append('var_%s'%name)
else:
string_list.append(str(value))
parameter_string = ",".join(string_list)

#
Expand Down Expand Up @@ -454,11 +458,12 @@ def _sesolve_list_td(H_func, psi0, tlist, e_ops, args, opt, progress_bar):
"],Hptrs[" + str(k) + "],")

if args:
td_consts = args.items()
for elem in td_consts:
string += str(elem[1])
if elem != td_consts[-1]:
string += (",")
for name, value in args.items():
if isinstance(value, np.ndarray):
globals()['var_%s'%name] = value
string += 'var_%s,'%name
else:
string += str(value) + ','

# run code generator
if not opt.rhs_reuse or odeconfig.tdfunc is None:
Expand Down

0 comments on commit 923bf99

Please sign in to comment.