Skip to content

Commit

Permalink
Add IR to Jet (#11)
Browse files Browse the repository at this point in the history
* add ir

* move tests

* run black

* add requirements

* update black command

* update changelog

* rename IR to XIR

* update changelog (again)

* run isort

* run isort (again)

* update docstrings

* Apply suggestions from code review

Co-authored-by: Josh Izaac <[email protected]>
Co-authored-by: Mikhail Andrenkov <[email protected]>

Co-authored-by: Josh Izaac <[email protected]>
Co-authored-by: Mikhail Andrenkov <[email protected]>
  • Loading branch information
3 people authored May 25, 2021
1 parent 45c6020 commit c62d635
Show file tree
Hide file tree
Showing 17 changed files with 2,041 additions and 6 deletions.
4 changes: 3 additions & 1 deletion .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

* Running CMake with `-DBUILD_PYTHON=ON` now generates Python bindings within a `jet` package. [(#1)](https://github.com/XanaduAI/jet/pull/1)

* A new intermediate representation (IR) is added, including a parser, IR representation program, and a Strawberry Fields interface. [(#11)](https://github.com/XanaduAI/jet/pull/11)

### Improvements

* Exceptions are now favoured in place of `std::terminate` with `Exception` being the new base type for all exceptions thrown by Jet. [(#3)](https://github.com/XanaduAI/jet/pull/3)
Expand All @@ -32,7 +34,7 @@

This release contains contributions from (in alphabetical order):

[Mikhail Andrenkov](https://github.com/Mandrenkov), [Jack Brown](https://github.com/brownj85), [Lee J. O'Riordan](https://github.com/mlxd), [Trevor Vincent](https://github.com/trevor-vincent).
[Mikhail Andrenkov](https://github.com/Mandrenkov), [Jack Brown](https://github.com/brownj85), [Theodor Isacsson](https://github.com/thisac), [Lee J. O'Riordan](https://github.com/mlxd), [Trevor Vincent](https://github.com/trevor-vincent).

## Release 0.1.0 (current release)

Expand Down
4 changes: 2 additions & 2 deletions python/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ setup: $(.VENV_DIR)/requirements_test.txt.touch
.PHONY: format
format:
ifdef check
$(.VENV_BIN)/black --check tests && $(.VENV_BIN)/isort --profile black --check-only tests
$(.VENV_BIN)/black -l 100 --check tests && $(.VENV_BIN)/isort --profile black --check-only tests
else
$(.VENV_BIN)/black tests && $(.VENV_BIN)/isort --profile black tests
$(.VENV_BIN)/black -l 100 tests && $(.VENV_BIN)/isort --profile black tests
endif

.PHONY: test
Expand Down
2 changes: 2 additions & 0 deletions python/requirements_test.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
black
isort>5
lark
pytest>=5,<6
strawberryfields
4 changes: 1 addition & 3 deletions python/tests/test_tensor_network_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
import pytest


@pytest.mark.parametrize(
"TensorNetworkFile", [jet.TensorNetworkFile32, jet.TensorNetworkFile64]
)
@pytest.mark.parametrize("TensorNetworkFile", [jet.TensorNetworkFile32, jet.TensorNetworkFile64])
def test_tensor_network_file(TensorNetworkFile):
"""Tests that a tensor network file can be constructed."""
tnf = TensorNetworkFile()
Expand Down
93 changes: 93 additions & 0 deletions python/tests/xir/test_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Integration tests for the IR"""

import pytest

from xir import XIRTransformer, xir_parser
from xir.program import XIRProgram
from xir.utils import is_equal


def parse_script(circuit: str) -> XIRProgram:
"""Parse and transform a circuit XIR script and return an XIRProgram"""
tree = xir_parser.parse(circuit)
return XIRTransformer().transform(tree)


photonics_script = """
gate Sgate, 2, 1;
gate BSgate, 2, 2;
gate Rgate, 1, 1;
output MeasureHomodyne;
Sgate(0.7, 0) | [1];
BSgate(0.1, 0.0) | [0, 1];
Rgate(0.2) | [1];
MeasureHomodyne(phi: 3) | [0];
"""

photonics_script_no_decl = """
use xstd;
Sgate(0.7, 0) | [1];
BSgate(0.1, 0.0) | [0, 1];
Rgate(0.2) | [1];
MeasureHomodyne(phi: 3) | [0];
"""

qubit_script = """
// this file works with the current parser
use xstd;
gate h(a)[0, 1]:
rz(-2.3932854391951004) | [0];
rz(a) | [1];
// rz(pi / sin(3 * 4 / 2 - 2)) | [a, 2];
end;
operator o(a):
0.7 * sin(a), X[0] @ Z[1];
-1.6, X[0];
2.45, Y[0] @ X[1];
end;
g_one(pi) | [0, 1];
g_two | [2];
g_three(1, 3.3) | [2];
// The circuit and statistics
ry(1.23) | [0];
rot(0.1, 0.2, 0.3) | [1];
h(0.2) | [0, 1, 2];
sample(observable: o(0.2), shots: 1000) | [0, 1];
"""


class TestParser:
"""Integration tests for parsing, and serializing, XIR scripts"""

@pytest.mark.parametrize("circuit", [qubit_script, photonics_script, photonics_script_no_decl])
def test_parse_and_serialize(self, circuit):
"""Test parsing and serializing an XIR script.
Tests parsing, serializing as well as the ``is_equal`` utils function.
"""
irprog = parse_script(circuit)
res = irprog.serialize()
assert is_equal(res, circuit)
193 changes: 193 additions & 0 deletions python/tests/xir/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# Copyright 2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the interfaces module"""

from decimal import Decimal
from typing import List, Tuple

import pytest
import strawberryfields as sf
from strawberryfields import ops

from xir.interfaces.strawberryfields_io import to_program, to_xir
from xir.program import GateDeclaration, Statement, XIRProgram


def create_xir_prog(
data: List[Tuple],
external_libs: List[str] = None,
include_decl: bool = True,
version: str = None,
) -> XIRProgram:
"""Create an XIRProgram object used for testing"""
# if no version number is passed, use the default one (by not specifying it)
if version is None:
irprog = XIRProgram()
else:
irprog = XIRProgram(version=version)

# add the statements to the program
stmts = [Statement(n, p, w) for n, p, w in data]
irprog._statements.extend(stmts)

# if declaration should be included, add them to the program
if include_decl:
declarations = [GateDeclaration(n, len(p), len(w)) for n, p, w in data]
irprog._declarations["gate"].extend(declarations)

# if any external libraries/files are included, add them to the program
if external_libs is not None:
irprog._include.extend(external_libs)

return irprog


def create_sf_prog(num_of_wires: int, ops: List[Tuple]):
"""Create a Strawberry Fields program"""
prog = sf.Program(num_of_wires)

with prog.context as q:
for gate, params, wires in ops:
regrefs = [q[w] for w in wires]
if len(params) == 0:
gate | regrefs
else:
gate(*params) | regrefs

return prog


class TestXIRToStrawberryFields:
"""Unit tests for the XIR to Strawberry Fields conversion"""

def test_empty_irprogram(self):
"""Test that converting an empty XIR program raises an error"""
irprog = create_xir_prog(data=[])
with pytest.raises(ValueError, match="XIR program is empty and cannot be transformed"):
to_program(irprog)

def test_gate_not_defined(self):
"""Test unknown gate raises error"""
circuit_data = [
("not_a_real_gate", [Decimal("0.42")], (1, 2, 3)),
]
irprog = create_xir_prog(data=circuit_data)

with pytest.raises(NameError, match="operation 'not_a_real_gate' not defined"):
to_program(irprog)

def test_gates_no_args(self):
"""Test that gates without arguments work"""
circuit_data = [
("Vac", [], (0,)),
]
irprog = create_xir_prog(data=circuit_data)

sfprog = to_program(irprog)

assert len(sfprog) == 1
assert sfprog.circuit[0].op.__class__.__name__ == "Vacuum"
assert sfprog.circuit[0].reg[0].ind == 0

def test_gates_with_args(self):
"""Test that gates with arguments work"""

circuit_data = [
("Sgate", [Decimal("0.1"), Decimal("0.2")], (0,)),
("Sgate", [Decimal("0.3")], (1,)),
]
irprog = create_xir_prog(data=circuit_data)
sfprog = to_program(irprog)

assert len(sfprog) == 2
assert sfprog.circuit[0].op.__class__.__name__ == "Sgate"
assert sfprog.circuit[0].op.p[0] == 0.1
assert sfprog.circuit[0].op.p[1] == 0.2
assert sfprog.circuit[0].reg[0].ind == 0

assert sfprog.circuit[1].op.__class__.__name__ == "Sgate"
assert sfprog.circuit[1].op.p[0] == 0.3
assert sfprog.circuit[1].op.p[1] == 0.0 # default value
assert sfprog.circuit[1].reg[0].ind == 1


class TestStrawberryFieldsToXIR:
"""Unit tests for the XIR to Strawberry Fields conversion"""

def test_empty_sfprogram(self):
"""Test that converting from an empty SF program works"""
sfprog = create_sf_prog(num_of_wires=2, ops=[])
irprog = to_xir(sfprog)

assert irprog.version == "0.1.0"
assert irprog.statements == []
assert irprog.include == []
assert irprog.statements == []
assert irprog.declarations == {"gate": [], "func": [], "output": [], "operator": []}

assert irprog.gates == dict()
assert irprog.operators == dict()
assert irprog.variables == set()
assert irprog._called_ops == set()

@pytest.mark.parametrize("add_decl", [True, False])
def test_gates_no_args(self, add_decl):
"""Test unknown gate raises error"""
circuit_data = [
(ops.Vac, [], (0,)),
]

sfprog = create_sf_prog(num_of_wires=2, ops=circuit_data)
irprog = to_xir(sfprog, add_decl=add_decl)

assert irprog.statements[0].name == "Vacuum"
assert irprog.statements[0].params == []
assert irprog.statements[0].wires == (0,)

if add_decl:
assert len(irprog.declarations["gate"]) == 1
assert irprog.declarations["gate"][0].name == "Vacuum"
assert irprog.declarations["gate"][0].num_params == 0
assert irprog.declarations["gate"][0].num_wires == 1
else:
assert irprog.declarations["gate"] == []

@pytest.mark.parametrize("add_decl", [True, False])
def test_gates_with_args(self, add_decl):
"""Test that gates with arguments work"""

circuit_data = [
(ops.Sgate, [0.1, 0.2], (0,)),
(ops.Sgate, [Decimal("0.3")], (1,)),
]

sfprog = create_sf_prog(num_of_wires=2, ops=circuit_data)
irprog = to_xir(sfprog, add_decl=add_decl)

assert irprog.statements[0].name == "Sgate"
assert irprog.statements[0].params == [0.1, 0.2]
assert irprog.statements[0].wires == (0,)

assert irprog.statements[1].name == "Sgate"
assert irprog.statements[1].params == [Decimal("0.3"), 0.0]
assert irprog.statements[1].wires == (1,)

if add_decl:
assert len(irprog.declarations["gate"]) == 1
assert irprog.declarations["gate"][0].name == "Sgate"
assert irprog.declarations["gate"][0].num_params == 2
assert irprog.declarations["gate"][0].num_wires == 1
else:
assert irprog.declarations["gate"] == []
Loading

0 comments on commit c62d635

Please sign in to comment.