Skip to content

Commit

Permalink
Add custom gate support to mpl backend (Qiskit#2070)
Browse files Browse the repository at this point in the history
* Add custom gate support to mpl backend

This commit adds support for custom gate types to the mpl backend. It
handles arbitrary sized gates and uses the node name as the gate name.
By default all the gates are wide since the names are of unknown length.
A future improvement would be to make the gate width dynamic based on
the name length.

Fixes Qiskit#1941

* Remove stray print

* Updates per review comments

* Fix lint

* Add dynamic spacing and sizing for gates

* Fix lint

* Add support for gates with non-contiguous qubits

* Add input qubit annotation to custom gates
  • Loading branch information
mtreinish authored and ajavadia committed Apr 11, 2019
1 parent 4c482ab commit f6cfde6
Showing 1 changed file with 109 additions and 8 deletions.
117 changes: 109 additions & 8 deletions qiskit/visualization/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,70 @@ def _registers(self, creg, qreg):
def ast(self):
return self._ast

def _custom_multiqubit_gate(self, xy, fc=None, wide=True, text=None,
subtext=None):
xpos = min([x[0] for x in xy])
ypos = min([y[1] for y in xy])
ypos_max = max([y[1] for y in xy])
if wide:
if subtext:
boxes_length = round(max([len(text), len(subtext)]) / 8) or 1
else:
boxes_length = round(len(text) / 8) or 1
wid = WID * 2.8 * boxes_length
else:
wid = WID
if fc:
_fc = fc
else:
_fc = self._style.gc
qubit_span = abs(ypos) - abs(ypos_max) + 1
height = HIG * (qubit_span + 1)
box = patches.Rectangle(
xy=(xpos - 0.5 * wid, ypos - .5 * HIG),
width=wid, height=height, fc=_fc, ec=self._style.lc,
linewidth=1.5, zorder=PORDER_GATE)
self.ax.add_patch(box)
# Annotate inputs
for bit, y in enumerate([x[1] for x in xy]):
self.ax.text(xpos - 0.45 * wid, y, str(bit), ha='left', va='center',
fontsize=self._style.fs, color=self._style.gt,
clip_on=True, zorder=PORDER_TEXT)

if text:
disp_text = text
if subtext:
self.ax.text(xpos, ypos + 0.15 * height, disp_text, ha='center',
va='center', fontsize=self._style.fs,
color=self._style.gt, clip_on=True,
zorder=PORDER_TEXT)
self.ax.text(xpos, ypos - 0.3 * height, subtext, ha='center',
va='center', fontsize=self._style.sfs,
color=self._style.sc, clip_on=True,
zorder=PORDER_TEXT)
else:
self.ax.text(xpos, ypos + .5 * (qubit_span - 1), disp_text,
ha='center',
va='center',
fontsize=self._style.fs,
color=self._style.gt,
clip_on=True,
zorder=PORDER_TEXT)

def _gate(self, xy, fc=None, wide=False, text=None, subtext=None):
xpos, ypos = xy

if wide:
wid = WID * 2.8
if subtext:
wid = WID * 2.8
else:
boxes_wide = round(len(text) / 10) or 1
wid = WID * 2.8 * boxes_wide
else:
wid = WID
if fc:
_fc = fc
elif text:
elif text and text in self._style.dispcol:
_fc = self._style.dispcol[text]
else:
_fc = self._style.gc
Expand All @@ -166,7 +220,10 @@ def _gate(self, xy, fc=None, wide=False, text=None, subtext=None):
self.ax.add_patch(box)

if text:
disp_text = "${}$".format(self._style.disptex[text])
if text in self._style.dispcol:
disp_text = "${}$".format(self._style.disptex[text])
else:
disp_text = text
if subtext:
self.ax.text(xpos, ypos + 0.15 * HIG, disp_text, ha='center',
va='center', fontsize=self._style.fs,
Expand Down Expand Up @@ -480,14 +537,35 @@ def _draw_ops(self, verbose=False):
layer_width = 1

for op in layer:

if op.name in _wide_gate:
layer_width = 2
if layer_width < 2:
layer_width = 2
# if custom gate with a longer than standard name determine
# width
elif op.name not in ['barrier', 'snapshot', 'load', 'save',
'noise', 'cswap', 'swap'] and len(
op.name) >= 4:
box_width = round(len(op.name) / 8)
# If more than 4 characters min width is 2
if box_width <= 1:
box_width = 2
if layer_width < box_width:
if box_width > 2:
layer_width = box_width * 2
else:
layer_width = 2

this_anc = prev_anc + 1

for op in layer:

_iswide = op.name in _wide_gate
if op.name not in ['barrier', 'snapshot', 'load', 'save',
'noise', 'cswap', 'swap'] and len(
op.name) >= 4:
_iswide = True

# get qreg index
q_idxs = []
for qarg in op.qargs:
Expand Down Expand Up @@ -592,6 +670,8 @@ def _draw_ops(self, verbose=False):
if op.name == 'cx':
self._ctrl_qubit(q_xy[0])
self._tgt_qubit(q_xy[1])
# add qubit-qubit wiring
self._line(qreg_b, qreg_t)
# cz for latexmode
elif op.name == 'cz':
if self._style.latexmode:
Expand All @@ -601,6 +681,8 @@ def _draw_ops(self, verbose=False):
disp = op.name.replace('c', '')
self._ctrl_qubit(q_xy[0])
self._gate(q_xy[1], wide=_iswide, text=disp)
# add qubit-qubit wiring
self._line(qreg_b, qreg_t)
# control gate
elif op.name in ['cy', 'ch', 'cu3', 'crz']:
disp = op.name.replace('c', '')
Expand All @@ -610,6 +692,8 @@ def _draw_ops(self, verbose=False):
subtext='{}'.format(param))
else:
self._gate(q_xy[1], wide=_iswide, text=disp)
# add qubit-qubit wiring
self._line(qreg_b, qreg_t)
# cu1 for latexmode
elif op.name == 'cu1':
disp = op.name.replace('c', '')
Expand All @@ -620,12 +704,18 @@ def _draw_ops(self, verbose=False):
else:
self._gate(q_xy[1], wide=_iswide, text=disp,
subtext='{}'.format(param))
# add qubit-qubit wiring
self._line(qreg_b, qreg_t)
# swap gate
elif op.name == 'swap':
self._swap(q_xy[0])
self._swap(q_xy[1])
# add qubit-qubit wiring
self._line(qreg_b, qreg_t)
# add qubit-qubit wiring
self._line(qreg_b, qreg_t)
# Custom gate
else:
self._custom_multiqubit_gate(q_xy, wide=_iswide,
text=op.name)
#
# draw multi-qubit gates (n=3)
#
Expand All @@ -635,13 +725,24 @@ def _draw_ops(self, verbose=False):
self._ctrl_qubit(q_xy[0])
self._swap(q_xy[1])
self._swap(q_xy[2])
# add qubit-qubit wiring
self._line(qreg_b, qreg_t)
# ccx gate
elif op.name == 'ccx':
self._ctrl_qubit(q_xy[0])
self._ctrl_qubit(q_xy[1])
self._tgt_qubit(q_xy[2])
# add qubit-qubit wiring
self._line(qreg_b, qreg_t)
# add qubit-qubit wiring
self._line(qreg_b, qreg_t)
# custom gate
else:
self._custom_multiqubit_gate(q_xy, wide=_iswide,
text=op.name)

# draw custom multi-qubit gate
elif len(q_xy) > 3:
self._custom_multiqubit_gate(q_xy, wide=_iswide,
text=op.name)
else:
logger.critical('Invalid gate %s', op)
raise exceptions.VisualizationError('invalid gate {}'.format(op))
Expand Down

0 comments on commit f6cfde6

Please sign in to comment.