Skip to content

Commit

Permalink
[Fix] Device inference
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikandreasseitz committed Jan 31, 2024
1 parent f892cb0 commit ddeac31
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 87 deletions.
10 changes: 3 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ authors = [
]
requires-python = ">=3.8,<3.12"
license = {text = "Apache 2.0"}
version = "1.0.3"
version = "1.0.4"
classifiers=[
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python",
Expand All @@ -28,14 +28,10 @@ classifiers=[
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = [
"numpy",
"torch",
]
dependencies = ["torch"]

[project.optional-dependencies]
dev = ["black", "pytest", "pytest-xdist", "pytest-cov", "flake8", "mypy", "pre-commit", "ruff", "nbconvert",
"ipykernel", "matplotlib"]
dev = ["black", "pytest", "pytest-xdist", "pytest-cov", "flake8", "mypy", "pre-commit", "ruff", "nbconvert", "matplotlib"]

[tool.hatch.envs.tests]
features = [
Expand Down
57 changes: 0 additions & 57 deletions pyqtorch/abstract.py

This file was deleted.

36 changes: 19 additions & 17 deletions pyqtorch/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,25 @@

from typing import Any, Iterator

import torch
from torch import Tensor, device
from torch.nn import Module, ModuleList

from pyqtorch.abstract import AbstractOperator
from pyqtorch.utils import DiffMode, State, overlap, zero_state


class QuantumCircuit(torch.nn.Module):
def __init__(
self, n_qubits: int, operations: list[AbstractOperator], diff_mode: DiffMode = DiffMode.AD
):
class QuantumCircuit(Module):
def __init__(self, n_qubits: int, operations: list[Module], diff_mode: DiffMode = DiffMode.AD):
super().__init__()
self.n_qubits = n_qubits
self.operations = torch.nn.ModuleList(operations)
self.operations = ModuleList(operations)
self.diff_mode = diff_mode

def __mul__(self, other: AbstractOperator | QuantumCircuit) -> QuantumCircuit:
def __mul__(self, other: Module | QuantumCircuit) -> QuantumCircuit:
n_qubits = max(self.n_qubits, other.n_qubits)
if isinstance(other, QuantumCircuit):
return QuantumCircuit(n_qubits, self.operations.extend(other.operations))

elif isinstance(other, AbstractOperator):
elif isinstance(other, Module):
return QuantumCircuit(n_qubits, self.operations.append(other))

else:
Expand All @@ -43,22 +41,22 @@ def __eq__(self, other: Any) -> bool:
def __hash__(self) -> int:
return hash(self.__key())

def run(self, state: State = None, values: dict[str, torch.Tensor] = {}) -> State:
def run(self, state: State = None, values: dict[str, Tensor] = {}) -> State:
if state is None:
state = self.init_state()
for op in self.operations:
state = op(state, values)
return state

def forward(self, state: State, values: dict[str, torch.Tensor] = {}) -> State:
def forward(self, state: State, values: dict[str, Tensor] = {}) -> State:
return self.run(state, values)

def expectation(
self,
values: dict[str, torch.Tensor],
values: dict[str, Tensor],
observable: QuantumCircuit,
state: State = None,
) -> torch.Tensor:
) -> Tensor:
if observable is None:
raise ValueError("Please provide an observable to compute expectation.")
if state is None:
Expand All @@ -74,15 +72,19 @@ def expectation(
)

@property
def _device(self) -> torch.device:
def _device(self) -> device:
try:
(_, buffer) = next(self.named_buffers())
return buffer.device
except StopIteration:
return torch.device("cpu")
return device("cpu")

def init_state(self, batch_size: int = 1) -> torch.Tensor:
def init_state(self, batch_size: int = 1) -> Tensor:
return zero_state(self.n_qubits, batch_size, device=self._device)

def reverse(self) -> QuantumCircuit:
return QuantumCircuit(self.n_qubits, torch.nn.ModuleList(list(reversed(self.operations))))
return QuantumCircuit(self.n_qubits, ModuleList(list(reversed(self.operations))))

def to(self, device: device) -> QuantumCircuit:
self.operations = ModuleList([op.to(device) for op in self.operations])
return self
25 changes: 21 additions & 4 deletions pyqtorch/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,35 @@

import torch

from pyqtorch.abstract import AbstractOperator
from pyqtorch.apply import apply_operator
from pyqtorch.matrices import OPERATIONS_DICT, _controlled, _dagger
from pyqtorch.utils import Operator, State, product_state


class Primitive(AbstractOperator):
def __init__(self, pauli: torch.Tensor, target: int):
super().__init__(target)
class Primitive(torch.nn.Module):
def __init__(self, pauli: torch.Tensor, target: int) -> None:
super().__init__()
self.target: int = target
self.qubit_support: Tuple[int, ...] = (target,)
self.n_qubits: int = max(self.qubit_support)
self.register_buffer("pauli", pauli)
self._param_type = None

def __key(self) -> tuple:
return self.qubit_support

def __eq__(self, other: object) -> bool:
if isinstance(other, type(self)):
return self.__key() == other.__key()
else:
return False

def __hash__(self) -> int:
return hash(self.qubit_support)

def extra_repr(self) -> str:
return f"qubit_support={self.qubit_support}"

@property
def param_type(self) -> None:
return self._param_type
Expand Down
4 changes: 2 additions & 2 deletions tests/test_digital.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def test_dagger_single_qubit() -> None:
if issubclass(cls, Parametric):
op = cls(target, param_name) # type: ignore[arg-type]
else:
op = cls(target) # type: ignore[misc]
op = cls(target) # type: ignore[assignment, call-arg]
values = {param_name: torch.rand(1)} if param_name == "theta" else torch.rand(1)
new_state = apply_operator(state, op.unitary(values), [target])
daggered_back = apply_operator(new_state, op.dagger(values), [target])
Expand All @@ -257,7 +257,7 @@ def test_dagger_nqubit() -> None:
op = cls(target - 1, target, param_name) # type: ignore[arg-type]
qubit_support = (target + 1, target)
else:
op = cls(target - 1, target) # type: ignore[misc]
op = cls(target - 1, target) # type: ignore[call-arg]
qubit_support = (target + 1, target)
values = {param_name: torch.rand(1)} if param_name == "theta" else torch.rand(1)
new_state = apply_operator(state, op.unitary(values), qubit_support)
Expand Down

0 comments on commit ddeac31

Please sign in to comment.