Skip to content

Commit

Permalink
inner prod in adjoint
Browse files Browse the repository at this point in the history
  • Loading branch information
jpmoutinho committed Feb 9, 2024
1 parent a5639fb commit b103035
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
1 change: 1 addition & 0 deletions pyqtorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
Z,
)
from .utils import (
inner_prod,
is_normalized,
overlap,
random_state,
Expand Down
6 changes: 3 additions & 3 deletions pyqtorch/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pyqtorch.apply import apply_operator
from pyqtorch.circuit import QuantumCircuit
from pyqtorch.parametric import Parametric
from pyqtorch.utils import overlap, param_dict
from pyqtorch.utils import inner_prod, param_dict


class AdjointExpectation(Function):
Expand All @@ -30,7 +30,7 @@ def forward(
ctx.out_state = circuit.run(state, values)
ctx.projected_state = observable.run(ctx.out_state, values)
ctx.save_for_backward(*param_values)
return overlap(ctx.out_state, ctx.projected_state)
return inner_prod(ctx.out_state, ctx.projected_state).real

@staticmethod
@torch.no_grad()
Expand All @@ -43,7 +43,7 @@ def backward(ctx: Any, grad_out: Tensor) -> tuple:
if isinstance(op, Parametric):
if values[op.param_name].requires_grad:
mu = apply_operator(ctx.out_state, op.jacobian(values), op.qubit_support)
grad = grad_out * 2 * overlap(ctx.projected_state, mu)
grad = grad_out * 2 * inner_prod(ctx.projected_state, mu).real
else:
grad = torch.zeros(1)

Expand Down

0 comments on commit b103035

Please sign in to comment.