Skip to content

Commit

Permalink
Update Modify
Browse files Browse the repository at this point in the history
  • Loading branch information
itsubaki committed Jan 13, 2025
1 parent 2ef7377 commit 4714af2
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 35 deletions.
12 changes: 10 additions & 2 deletions visitor/gate.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import (
"github.com/itsubaki/q/quantum/gate"
)

// AddControlled returns a controlled-u gate with control bit.
// Controlled returns a controlled-u gate with control bit.
// u is a (2**n x 2**n) unitary matrix and returns a (2**n x 2**n) matrix.
func AddControlled(u matrix.Matrix, c []int) matrix.Matrix {
func Controlled(u matrix.Matrix, c []int) matrix.Matrix {
d, _ := u.Dimension()
n := number.Log2(d)
g := gate.I(n)
Expand All @@ -31,6 +31,14 @@ func AddControlled(u matrix.Matrix, c []int) matrix.Matrix {
return g
}

// NegControlled returns a controlled-u gate with control bit.
// u is a (2**n x 2**n) unitary matrix and returns a (2**n x 2**n) matrix.
func NegControlled(u matrix.Matrix, n int, c []int) matrix.Matrix {
x := gate.TensorProduct(gate.X(), n, c)
cu := Controlled(u, c)
return matrix.Apply(x, cu, x)
}

func Pow(u matrix.Matrix, p float64) matrix.Matrix {
// TODO: support float type
return matrix.ApplyN(u, int(p))
Expand Down
60 changes: 58 additions & 2 deletions visitor/gate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/itsubaki/qasm/visitor"
)

func TestAddControlled(t *testing.T) {
func TestControlled(t *testing.T) {
u := gate.U(rand.Float64(), rand.Float64(), rand.Float64())

cases := []struct {
Expand Down Expand Up @@ -70,7 +70,63 @@ func TestAddControlled(t *testing.T) {
}

for _, c := range cases {
got := visitor.AddControlled(c.in, []int{c.bit})
got := visitor.Controlled(c.in, []int{c.bit})
if !got.Equals(c.want) {
t.Fail()
}
}
}

func TestNegControlled(t *testing.T) {
cases := []struct {
in matrix.Matrix
want matrix.Matrix
n, bit int
}{
{
in: gate.TensorProduct(gate.X(), 2, []int{1}),
want: matrix.Apply(
gate.TensorProduct(gate.X(), 2, []int{0}),
gate.ControlledNot(2, []int{0}, 1),
gate.TensorProduct(gate.X(), 2, []int{0}),
),
n: 2,
bit: 0,
},
{
in: gate.TensorProduct(gate.X(), 2, []int{0}),
want: matrix.Apply(
gate.TensorProduct(gate.X(), 2, []int{1}),
gate.ControlledNot(2, []int{1}, 0),
gate.TensorProduct(gate.X(), 2, []int{1}),
),
n: 2,
bit: 1,
},
{
in: gate.ControlledNot(3, []int{0}, 2),
want: matrix.Apply(
gate.TensorProduct(gate.X(), 3, []int{1}),
gate.ControlledNot(3, []int{0, 1}, 2),
gate.TensorProduct(gate.X(), 3, []int{1}),
),
n: 3,
bit: 1,
},
{
in: gate.ControlledNot(3, []int{0}, 1),
want: matrix.Apply(
gate.TensorProduct(gate.X(), 3, []int{2}),
gate.ControlledNot(3, []int{0, 2}, 1),
gate.TensorProduct(gate.X(), 3, []int{2}),
),
n: 3,
bit: 2,
},
}

for _, c := range cases {
got := visitor.NegControlled(c.in, c.n, []int{c.bit})
if !got.Equals(c.want) {
t.Fail()
}
Expand Down
61 changes: 30 additions & 31 deletions visitor/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,35 +311,6 @@ func (v *Visitor) Params(xlist parser.IExpressionListContext) ([]float64, error)
return params, nil
}

func (v *Visitor) Modify(u matrix.Matrix, qargs [][]q.Qubit, modifier []parser.IGateModifierContext) (matrix.Matrix, error) {
for i, mod := range modifier {
switch {
case mod.CTRL() != nil:
u = AddControlled(u, q.Index(qargs[i]...))
case mod.NEGCTRL() != nil:
x := gate.TensorProduct(gate.X(), v.qsim.NumberOfBit(), q.Index(qargs[i]...))
u = AddControlled(u, q.Index(qargs[i]...))
u = matrix.Apply(x, u, x)
case mod.INV() != nil:
u = u.Dagger()
case mod.POW() != nil:
var p float64
switch n := v.Visit(mod).(type) {
case float64:
p = n
case int64:
p = float64(n)
default:
return nil, fmt.Errorf("pow=%v: %w", n, ErrUnexpected)
}

u = Pow(u, p)
}
}

return u, nil
}

func (v *Visitor) Builtin(ctx *parser.GateCallStatementContext) (matrix.Matrix, bool, error) {
if ctx.GPHASE() != nil {
params, err := v.Params(ctx.ExpressionList())
Expand Down Expand Up @@ -368,7 +339,7 @@ func (v *Visitor) Builtin(ctx *parser.GateCallStatementContext) (matrix.Matrix,
u = gate.TensorProduct(u, n, q.Index(flatten(qargs)...))

// modify
u, err = v.Modify(u, qargs, ctx.AllGateModifier())
u, err = v.Modify(u, n, qargs, ctx.AllGateModifier())
if err != nil {
return nil, false, fmt.Errorf("modify: %w", err)
}
Expand All @@ -379,6 +350,33 @@ func (v *Visitor) Builtin(ctx *parser.GateCallStatementContext) (matrix.Matrix,
}
}

func (v *Visitor) Modify(u matrix.Matrix, n int, qargs [][]q.Qubit, modifier []parser.IGateModifierContext) (matrix.Matrix, error) {
for i, mod := range modifier {
switch {
case mod.CTRL() != nil:
u = Controlled(u, q.Index(qargs[i]...))
case mod.NEGCTRL() != nil:
u = NegControlled(u, n, q.Index(qargs[i]...))
case mod.INV() != nil:
u = u.Dagger()
case mod.POW() != nil:
var p float64
switch n := v.Visit(mod).(type) {
case float64:
p = n
case int64:
p = float64(n)
default:
return nil, fmt.Errorf("pow=%v: %w", n, ErrUnexpected)
}

u = Pow(u, p)
}
}

return u, nil
}

func (v *Visitor) Defined(ctx *parser.GateCallStatementContext) (matrix.Matrix, error) {
id := v.Visit(ctx.Identifier()).(string)
g, ok := v.Environ.GetGate(id)
Expand Down Expand Up @@ -439,7 +437,8 @@ func (v *Visitor) Defined(ctx *parser.GateCallStatementContext) (matrix.Matrix,
u := matrix.Apply(list...)

// modify
u, err := v.Modify(u, qargs, ctx.AllGateModifier())
n := v.qsim.NumberOfBit()
u, err := v.Modify(u, n, qargs, ctx.AllGateModifier())
if err != nil {
return nil, fmt.Errorf("modify: %w", err)
}
Expand Down

0 comments on commit 4714af2

Please sign in to comment.