Skip to content

Commit

Permalink
feat: made encoding and decodings hashable
Browse files Browse the repository at this point in the history
  • Loading branch information
BrunoLiegiBastonLiegi committed Jan 27, 2025
1 parent db74830 commit 0be032f
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 23 deletions.
23 changes: 12 additions & 11 deletions src/qiboml/models/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from qibo import Circuit, gates
from qibo.backends import Backend, _check_backend
from qibo.config import raise_error
from qibo.hamiltonians import Hamiltonian
from qibo.hamiltonians import Hamiltonian, Z

from qiboml import ndarray

Expand All @@ -13,14 +13,15 @@
class QuantumDecoding:

nqubits: int
qubits: list[int] = None
qubits: tuple[int] = None
nshots: int = None
backend: Backend = None
_circuit: Circuit = None

def __post_init__(self):
if self.qubits is None:
self.qubits = list(range(self.nqubits))
self.qubits = (
tuple(range(self.nqubits)) if self.qubits is None else tuple(self.qubits)
)
self._circuit = Circuit(self.nqubits)
self.backend = _check_backend(self.backend)
self._circuit.add(gates.M(*self.qubits))
Expand All @@ -47,8 +48,10 @@ def analytic(self):
return True
return False

def __hash__(self) -> int:
return hash((self.qubits, self.nshots, self.backend))


@dataclass
class Probabilities(QuantumDecoding):
# TODO: collapse on ExpectationDecoding if not analytic

Expand All @@ -71,10 +74,7 @@ class Expectation(QuantumDecoding):

def __post_init__(self):
if self.observable is None:
raise_error(
RuntimeError,
"Please provide an observable for expectation value calculation.",
)
self.observable = Z(self.nqubits, dense=True, backend=self.backend)
super().__post_init__()

def __call__(self, x: Circuit) -> ndarray:
Expand All @@ -96,8 +96,10 @@ def set_backend(self, backend):
super().set_backend(backend)
self.observable.backend = backend

def __hash__(self) -> int:
return hash((self.qubits, self.nshots, self.backend, self.observable))


@dataclass
class State(QuantumDecoding):

def __call__(self, x: Circuit) -> ndarray:
Expand All @@ -115,7 +117,6 @@ def analytic(self):
return True


@dataclass
class Samples(QuantumDecoding):

def __post_init__(self):
Expand Down
14 changes: 9 additions & 5 deletions src/qiboml/models/encoding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional

import numpy as np
from qibo import Circuit, gates
Expand All @@ -12,14 +13,16 @@
class QuantumEncoding(ABC):

nqubits: int
qubits: list[int] = None
qubits: Optional[tuple[int]] = None
_circuit: Circuit = None

def __post_init__(
self,
):
if self.qubits is None:
self.qubits = list(range(self.nqubits))
self.qubits = (
tuple(range(self.nqubits)) if self.qubits is None else tuple(self.qubits)
)

self._circuit = Circuit(self.nqubits)

@abstractmethod
Expand All @@ -36,8 +39,10 @@ def circuit(
def differentiable(self):
return True

def __hash__(self) -> int:
return hash(self.qubits)


@dataclass
class PhaseEncoding(QuantumEncoding):

def __post_init__(
Expand All @@ -56,7 +61,6 @@ def __call__(self, x: ndarray) -> Circuit:
return self._circuit


@dataclass
class BinaryEncoding(QuantumEncoding):

def __call__(self, x: ndarray) -> Circuit:
Expand Down
10 changes: 6 additions & 4 deletions src/qiboml/operations/differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,25 +164,27 @@ def evaluate(
self,
"_jacobian",
partial(jax.jit, static_argnums=(1, 2, 3))(
jax.jacfwd(self._run, (0,) + self._argnums)
jax.jacfwd(self._run, (0,) + self._argnums),
),
)
setattr(
self,
"_jacobian_without_inputs",
partial(jax.jit, static_argnums=(1, 2, 3))(
jax.jacfwd(self._run, self._argnums)
jax.jacfwd(self._run, self._argnums),
),
)
parameters = backend.to_numpy(list(parameters))
parameters = self._jax.cast(parameters, parameters.dtype)
decoding.set_backend(self._jax)
if wrt_inputs:
gradients = self._jacobian(x, encoding, circuit, decoding, *parameters)
gradients = self._jacobian(
x, encoding, circuit, decoding, *parameters
) # pylint: disable=no-member
else:
gradients = (
self._jax.numpy.zeros((decoding.output_shape[-1], x.shape[-1])),
self._jacobian_without_inputs(
self._jacobian_without_inputs( # pylint: disable=no-member
x, encoding, circuit, decoding, *parameters
),
)
Expand Down
3 changes: 0 additions & 3 deletions tests/test_models_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ def test_expectation_layer(backend, nshots):
backend.set_seed(42)
rng = np.random.default_rng(42)
nqubits = 5
# test observable error
with pytest.raises(RuntimeError):
layer = dec.Expectation(nqubits, backend=backend)

c = random_clifford(nqubits, seed=rng, backend=backend)
observable = hamiltonians.SymbolicHamiltonian(
Expand Down

0 comments on commit 0be032f

Please sign in to comment.