From 1ea79613bb94f658c0312291436018e0caa2415b Mon Sep 17 00:00:00 2001 From: itsubaki <1759459+itsubaki@users.noreply.github.com> Date: Thu, 16 Jan 2025 21:34:59 +0900 Subject: [PATCH] Update some files --- Makefile | 2 +- cmd/parse/main.go | 11 --- main.go | 42 ++++++++ visitor/{gate.go => modifier.go} | 4 +- visitor/{gate_test.go => modifier_test.go} | 12 +-- visitor/visitor.go | 107 ++++++++++----------- visitor/visitor_test.go | 1 + 7 files changed, 104 insertions(+), 75 deletions(-) create mode 100644 main.go rename visitor/{gate.go => modifier.go} (90%) rename visitor/{gate_test.go => modifier_test.go} (94%) diff --git a/Makefile b/Makefile index cdb621e..eda2964 100644 --- a/Makefile +++ b/Makefile @@ -16,4 +16,4 @@ parse: go run cmd/parse/main.go < _testdata/bell.qasm test: - go test -v -cover $(shell go list ./... | grep -v /cmd | grep -v /gen) -v -coverprofile=coverage.txt -covermode=atomic + go test -v -cover $(shell go list ./... | grep -v /cmd | grep -v /gen | grep -v -E "qasm$$") -v -coverprofile=coverage.txt -covermode=atomic diff --git a/cmd/parse/main.go b/cmd/parse/main.go index 9c10baf..6556f08 100644 --- a/cmd/parse/main.go +++ b/cmd/parse/main.go @@ -5,10 +5,8 @@ import ( "os" "github.com/antlr4-go/antlr/v4" - "github.com/itsubaki/q" "github.com/itsubaki/qasm/gen/parser" "github.com/itsubaki/qasm/io" - "github.com/itsubaki/qasm/visitor" ) func main() { @@ -19,13 +17,4 @@ func main() { tree := p.Program() fmt.Println(tree.ToStringTree(nil, p)) - - qsim := q.New() - env := visitor.NewEnviron() - v := visitor.New(qsim, env) - - switch ret := v.Visit(tree).(type) { - case error: - panic(ret) - } } diff --git a/main.go b/main.go new file mode 100644 index 0000000..e10334b --- /dev/null +++ b/main.go @@ -0,0 +1,42 @@ +package main + +import ( + "flag" + "fmt" + "os" + + "github.com/antlr4-go/antlr/v4" + "github.com/itsubaki/q" + "github.com/itsubaki/qasm/gen/parser" + "github.com/itsubaki/qasm/visitor" +) + +func main() { + var filepath string + flag.StringVar(&filepath, "f", "", "filepath") + flag.Parse() + + if filepath == "" { + fmt.Printf("Usage: %s -f filepath\n", os.Args[0]) + return + } + + text, err := os.ReadFile(filepath) + if err != nil { + panic(err) + } + + lexer := parser.Newqasm3Lexer(antlr.NewInputStream(string(text))) + p := parser.Newqasm3Parser(antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel)) + + qsim := q.New() + v := visitor.New(qsim, visitor.NewEnviron()) + + if err := v.Visit(p.Program()); err != nil { + panic(err) + } + + for _, s := range qsim.State() { + fmt.Println(s) + } +} diff --git a/visitor/gate.go b/visitor/modifier.go similarity index 90% rename from visitor/gate.go rename to visitor/modifier.go index 5e91e68..4e08f5e 100644 --- a/visitor/gate.go +++ b/visitor/modifier.go @@ -33,7 +33,9 @@ func Controlled(u matrix.Matrix, c []int) matrix.Matrix { // NegControlled returns a controlled-u gate with control bit. // u is a (2**n x 2**n) unitary matrix and returns a (2**n x 2**n) matrix. -func NegControlled(u matrix.Matrix, n int, c []int) matrix.Matrix { +func NegControlled(u matrix.Matrix, c []int) matrix.Matrix { + d, _ := u.Dimension() + n := number.Log2(d) x := gate.TensorProduct(gate.X(), n, c) cu := Controlled(u, c) return matrix.Apply(x, cu, x) diff --git a/visitor/gate_test.go b/visitor/modifier_test.go similarity index 94% rename from visitor/gate_test.go rename to visitor/modifier_test.go index eaeae7d..31c1b68 100644 --- a/visitor/gate_test.go +++ b/visitor/modifier_test.go @@ -79,9 +79,9 @@ func TestControlled(t *testing.T) { func TestNegControlled(t *testing.T) { cases := []struct { - in matrix.Matrix - want matrix.Matrix - n, bit int + in matrix.Matrix + want matrix.Matrix + bit int }{ { in: gate.TensorProduct(gate.X(), 2, []int{1}), @@ -90,7 +90,6 @@ func TestNegControlled(t *testing.T) { gate.ControlledNot(2, []int{0}, 1), gate.TensorProduct(gate.X(), 2, []int{0}), ), - n: 2, bit: 0, }, { @@ -100,7 +99,6 @@ func TestNegControlled(t *testing.T) { gate.ControlledNot(2, []int{1}, 0), gate.TensorProduct(gate.X(), 2, []int{1}), ), - n: 2, bit: 1, }, { @@ -110,7 +108,6 @@ func TestNegControlled(t *testing.T) { gate.ControlledNot(3, []int{0, 1}, 2), gate.TensorProduct(gate.X(), 3, []int{1}), ), - n: 3, bit: 1, }, { @@ -120,13 +117,12 @@ func TestNegControlled(t *testing.T) { gate.ControlledNot(3, []int{0, 2}, 1), gate.TensorProduct(gate.X(), 3, []int{2}), ), - n: 3, bit: 2, }, } for _, c := range cases { - got := visitor.NegControlled(c.in, c.n, []int{c.bit}) + got := visitor.NegControlled(c.in, []int{c.bit}) if !got.Equals(c.want) { t.Fail() } diff --git a/visitor/visitor.go b/visitor/visitor.go index de3fd23..db8c353 100644 --- a/visitor/visitor.go +++ b/visitor/visitor.go @@ -28,6 +28,11 @@ var ( ErrNotImplemented = errors.New("not implemented") ) +type Visitor struct { + qsim *q.Q + env *Environ +} + func New(qsim *q.Q, env *Environ) *Visitor { return &Visitor{ qsim, @@ -35,13 +40,8 @@ func New(qsim *q.Q, env *Environ) *Visitor { } } -type Visitor struct { - qsim *q.Q - Environ *Environ -} - func (v *Visitor) Enclosed() *Visitor { - return New(v.qsim, v.Environ.NewEnclosed()) + return New(v.qsim, v.env.NewEnclosed()) } func (v *Visitor) Visit(tree antlr.ParseTree) interface{} { @@ -70,7 +70,7 @@ func (v *Visitor) VisitChildren(node antlr.RuleNode) interface{} { func (v *Visitor) VisitProgram(ctx *parser.ProgramContext) interface{} { if ctx.Version() != nil { - v.Environ.Version = v.Visit(ctx.Version()).(string) + v.env.Version = v.Visit(ctx.Version()).(string) } for _, s := range ctx.AllStatementOrScope() { @@ -202,7 +202,7 @@ func (v *Visitor) VisitForStatement(ctx *parser.ForStatementContext) interface{} enclosed := v.Enclosed() for i := rx[0]; i < rx[1]; i++ { - enclosed.Environ.Variable[id] = i + enclosed.env.Variable[id] = i result := enclosed.Visit(ctx.StatementOrScope()) if contains(result, Break) { @@ -265,7 +265,7 @@ func (v *Visitor) VisitReturnStatement(ctx *parser.ReturnStatementContext) inter func (v *Visitor) VisitGateStatement(ctx *parser.GateStatementContext) interface{} { name := v.Visit(ctx.Identifier()).(string) - if _, ok := v.Environ.GetGate(name); ok { + if _, ok := v.env.GetGate(name); ok { return fmt.Errorf("identifier=%s: %w", name, ErrAlreadyDeclared) } @@ -285,7 +285,7 @@ func (v *Visitor) VisitGateStatement(ctx *parser.GateStatementContext) interface body = append(body, s.Statement().GateCallStatement().(*parser.GateCallStatementContext)) } - v.Environ.Gate[name] = Gate{ + v.env.Gate[name] = Gate{ Name: name, Params: params, QArgs: qargs, @@ -339,7 +339,7 @@ func (v *Visitor) Builtin(ctx *parser.GateCallStatementContext) (matrix.Matrix, u = gate.TensorProduct(u, n, q.Index(flatten(qargs)...)) // modify - u, err = v.Modify(u, n, qargs, ctx.AllGateModifier()) + u, err = v.Modify(u, qargs, ctx.AllGateModifier()) if err != nil { return nil, false, fmt.Errorf("modify: %w", err) } @@ -350,13 +350,13 @@ func (v *Visitor) Builtin(ctx *parser.GateCallStatementContext) (matrix.Matrix, } } -func (v *Visitor) Modify(u matrix.Matrix, n int, qargs [][]q.Qubit, modifier []parser.IGateModifierContext) (matrix.Matrix, error) { +func (v *Visitor) Modify(u matrix.Matrix, qargs [][]q.Qubit, modifier []parser.IGateModifierContext) (matrix.Matrix, error) { for i, mod := range modifier { switch { case mod.CTRL() != nil: u = Controlled(u, q.Index(qargs[i]...)) case mod.NEGCTRL() != nil: - u = NegControlled(u, n, q.Index(qargs[i]...)) + u = NegControlled(u, q.Index(qargs[i]...)) case mod.INV() != nil: u = u.Dagger() case mod.POW() != nil: @@ -379,7 +379,7 @@ func (v *Visitor) Modify(u matrix.Matrix, n int, qargs [][]q.Qubit, modifier []p func (v *Visitor) Defined(ctx *parser.GateCallStatementContext) (matrix.Matrix, error) { id := v.Visit(ctx.Identifier()).(string) - g, ok := v.Environ.GetGate(id) + g, ok := v.env.GetGate(id) if !ok { return nil, fmt.Errorf("idenfitier=%s: %w", id, ErrGateNotFound) } @@ -393,7 +393,7 @@ func (v *Visitor) Defined(ctx *parser.GateCallStatementContext) (matrix.Matrix, } for i, p := range g.Params { - enclosed.Environ.Variable[p] = params[i] + enclosed.env.Variable[p] = params[i] } } @@ -411,7 +411,7 @@ func (v *Visitor) Defined(ctx *parser.GateCallStatementContext) (matrix.Matrix, qargs = v.Visit(ctx.GateOperandList()).([][]q.Qubit) for i, id := range g.QArgs { - enclosed.Environ.Qubit[id] = qargs[i+shift] + enclosed.env.Qubit[id] = qargs[i+shift] } } @@ -437,8 +437,7 @@ func (v *Visitor) Defined(ctx *parser.GateCallStatementContext) (matrix.Matrix, u := matrix.Apply(list...) // modify - n := v.qsim.NumberOfBit() - u, err := v.Modify(u, n, qargs, ctx.AllGateModifier()) + u, err := v.Modify(u, qargs, ctx.AllGateModifier()) if err != nil { return nil, fmt.Errorf("modify: %w", err) } @@ -470,7 +469,7 @@ func (v *Visitor) MeasureAssignment(identifier parser.IIndexedIdentifierContext, } operand := v.Visit(identifier.Identifier()).(string) - bits, ok := v.Environ.GetClassicalBit(operand) + bits, ok := v.env.GetClassicalBit(operand) if !ok { return fmt.Errorf("operand=%s: %w", operand, ErrClassicalBitNotFound) } @@ -498,7 +497,7 @@ func (v *Visitor) VisitAssignmentStatement(ctx *parser.AssignmentStatementContex } id := v.Visit(ctx.IndexedIdentifier().Identifier()).(string) - v.Environ.SetVariable(id, v.Visit(ctx.Expression())) + v.env.SetVariable(id, v.Visit(ctx.Expression())) return nil } @@ -510,22 +509,22 @@ func (v *Visitor) VisitResetStatement(ctx *parser.ResetStatementContext) interfa func (v *Visitor) VisitConstDeclarationStatement(ctx *parser.ConstDeclarationStatementContext) interface{} { id := v.Visit(ctx.Identifier()).(string) - if _, ok := v.Environ.GetConst(id); ok { + if _, ok := v.env.GetConst(id); ok { return fmt.Errorf("identifier=%s: %w", id, ErrAlreadyDeclared) } - v.Environ.Const[id] = v.Visit(ctx.DeclarationExpression()) + v.env.Const[id] = v.Visit(ctx.DeclarationExpression()) return nil } func (v *Visitor) VisitQuantumDeclarationStatement(ctx *parser.QuantumDeclarationStatementContext) interface{} { id := v.Visit(ctx.Identifier()).(string) - if _, ok := v.Environ.GetQubit(id); ok { + if _, ok := v.env.GetQubit(id); ok { return fmt.Errorf("identifier=%s: %w", id, ErrAlreadyDeclared) } size := v.Visit(ctx.QubitType()).(int64) - v.Environ.Qubit[id] = v.qsim.ZeroWith(int(size)) + v.env.Qubit[id] = v.qsim.ZeroWith(int(size)) return nil } @@ -533,70 +532,70 @@ func (v *Visitor) VisitClassicalDeclarationStatement(ctx *parser.ClassicalDeclar switch { case ctx.ScalarType().BIT() != nil: id := v.Visit(ctx.Identifier()).(string) - if _, ok := v.Environ.GetClassicalBit(id); ok { + if _, ok := v.env.GetClassicalBit(id); ok { return fmt.Errorf("identifier=%s: %w", id, ErrAlreadyDeclared) } if ctx.DeclarationExpression() != nil { bits := v.Visit(ctx.DeclarationExpression()).([]int64) - v.Environ.ClassicalBit[id] = bits + v.env.ClassicalBit[id] = bits return nil } size := v.Visit(ctx.ScalarType()).(int64) - v.Environ.ClassicalBit[id] = make([]int64, int(size)) + v.env.ClassicalBit[id] = make([]int64, int(size)) return nil case ctx.ScalarType().FLOAT() != nil: id := v.Visit(ctx.Identifier()).(string) - if _, ok := v.Environ.GetVariable(id); ok { + if _, ok := v.env.GetVariable(id); ok { return fmt.Errorf("identifier=%s: %w", id, ErrAlreadyDeclared) } if ctx.DeclarationExpression() != nil { - v.Environ.Variable[id] = v.Visit(ctx.DeclarationExpression()) + v.env.Variable[id] = v.Visit(ctx.DeclarationExpression()) return nil } - v.Environ.Variable[id] = float64(0) + v.env.Variable[id] = float64(0) return nil case ctx.ScalarType().INT() != nil: id := v.Visit(ctx.Identifier()).(string) - if _, ok := v.Environ.GetVariable(id); ok { + if _, ok := v.env.GetVariable(id); ok { return fmt.Errorf("identifier=%s: %w", id, ErrAlreadyDeclared) } if ctx.DeclarationExpression() != nil { - v.Environ.Variable[id] = v.Visit(ctx.DeclarationExpression()) + v.env.Variable[id] = v.Visit(ctx.DeclarationExpression()) return nil } - v.Environ.Variable[id] = int(0) + v.env.Variable[id] = int(0) return nil case ctx.ScalarType().UINT() != nil: id := v.Visit(ctx.Identifier()).(string) - if _, ok := v.Environ.GetVariable(id); ok { + if _, ok := v.env.GetVariable(id); ok { return fmt.Errorf("identifier=%s: %w", id, ErrAlreadyDeclared) } if ctx.DeclarationExpression() != nil { - v.Environ.Variable[id] = v.Visit(ctx.DeclarationExpression()) + v.env.Variable[id] = v.Visit(ctx.DeclarationExpression()) return nil } - v.Environ.Variable[id] = uint(0) + v.env.Variable[id] = uint(0) return nil case ctx.ScalarType().BOOL() != nil: id := v.Visit(ctx.Identifier()).(string) - if _, ok := v.Environ.GetVariable(id); ok { + if _, ok := v.env.GetVariable(id); ok { return fmt.Errorf("identifier=%s: %w", id, ErrAlreadyDeclared) } if ctx.DeclarationExpression() != nil { - v.Environ.Variable[id] = v.Visit(ctx.DeclarationExpression()) + v.env.Variable[id] = v.Visit(ctx.DeclarationExpression()) return nil } - v.Environ.Variable[id] = false + v.env.Variable[id] = false return nil default: return fmt.Errorf("scalar type=%s: %w", ctx.ScalarType().GetText(), ErrUnexpected) @@ -605,7 +604,7 @@ func (v *Visitor) VisitClassicalDeclarationStatement(ctx *parser.ClassicalDeclar func (v *Visitor) VisitDefStatement(ctx *parser.DefStatementContext) interface{} { name := v.Visit(ctx.Identifier()).(string) - if _, ok := v.Environ.GetSubroutine(name); ok { + if _, ok := v.env.GetSubroutine(name); ok { return fmt.Errorf("identifier=%s: %w", name, ErrAlreadyDeclared) } @@ -615,7 +614,7 @@ func (v *Visitor) VisitDefStatement(ctx *parser.DefStatementContext) interface{} qargs = append(qargs, a.(string)) } - v.Environ.Subroutine[name] = Subroutine{ + v.env.Subroutine[name] = Subroutine{ Name: name, QArgs: qargs, Body: ctx.Scope().(*parser.ScopeContext), @@ -627,12 +626,12 @@ func (v *Visitor) VisitDefStatement(ctx *parser.DefStatementContext) interface{} func (v *Visitor) VisitAliasDeclarationStatement(ctx *parser.AliasDeclarationStatementContext) interface{} { id := v.Visit(ctx.Identifier()).(string) - if _, ok := v.Environ.GetQubit(id); ok { + if _, ok := v.env.GetQubit(id); ok { return fmt.Errorf("identifier=%s: %w", id, ErrAlreadyDeclared) } alias := v.Visit(ctx.AliasExpression()).([]q.Qubit) - v.Environ.Qubit[id] = alias + v.env.Qubit[id] = alias return nil } @@ -640,7 +639,7 @@ func (v *Visitor) VisitOldStyleDeclarationStatement(ctx *parser.OldStyleDeclarat switch { case ctx.QREG() != nil: id := v.Visit(ctx.Identifier()).(string) - if _, ok := v.Environ.GetQubit(id); ok { + if _, ok := v.env.GetQubit(id); ok { return fmt.Errorf("identifier=%s: %w", id, ErrAlreadyDeclared) } @@ -649,11 +648,11 @@ func (v *Visitor) VisitOldStyleDeclarationStatement(ctx *parser.OldStyleDeclarat size = v.Visit(ctx.Designator()).(int64) } - v.Environ.Qubit[id] = v.qsim.ZeroWith(int(size)) + v.env.Qubit[id] = v.qsim.ZeroWith(int(size)) return nil case ctx.CREG() != nil: id := v.Visit(ctx.Identifier()).(string) - if _, ok := v.Environ.GetClassicalBit(id); ok { + if _, ok := v.env.GetClassicalBit(id); ok { return fmt.Errorf("identifier=%s: %w", id, ErrAlreadyDeclared) } @@ -662,7 +661,7 @@ func (v *Visitor) VisitOldStyleDeclarationStatement(ctx *parser.OldStyleDeclarat size = v.Visit(ctx.Designator()).(int64) } - v.Environ.ClassicalBit[id] = make([]int64, int(size)) + v.env.ClassicalBit[id] = make([]int64, int(size)) return nil default: return fmt.Errorf("x=%s: %w", ctx.GetText(), ErrUnexpected) @@ -709,19 +708,19 @@ func (v *Visitor) VisitLiteralExpression(ctx *parser.LiteralExpressionContext) i return lit } - if lit, ok := v.Environ.GetConst(s); ok { + if lit, ok := v.env.GetConst(s); ok { return lit } - if lit, ok := v.Environ.GetVariable(s); ok { + if lit, ok := v.env.GetVariable(s); ok { return lit } - if lit, ok := v.Environ.GetQubit(s); ok { + if lit, ok := v.env.GetQubit(s); ok { return lit } - if lit, ok := v.Environ.GetClassicalBit(s); ok { + if lit, ok := v.env.GetClassicalBit(s); ok { return lit } @@ -1029,14 +1028,14 @@ func (v *Visitor) VisitCallExpression(ctx *parser.CallExpressionContext) interfa case "mod": return math.Mod(args[0].(float64), args[1].(float64)) default: - sub, ok := v.Environ.GetSubroutine(id) + sub, ok := v.env.GetSubroutine(id) if !ok { return fmt.Errorf("identifier=%s: %w", id, ErrFunctionNotFound) } enclosed := v.Enclosed() for i, p := range sub.QArgs { - enclosed.Environ.Qubit[p] = args[i].([]q.Qubit) + enclosed.env.Qubit[p] = args[i].([]q.Qubit) } result := enclosed.Visit(sub.Body).([]interface{}) @@ -1212,7 +1211,7 @@ func (v *Visitor) VisitGateOperand(ctx *parser.GateOperandContext) interface{} { indexID := ctx.IndexedIdentifier() operand := v.Visit(indexID.Identifier()).(string) - qb, ok := v.Environ.GetQubit(operand) + qb, ok := v.env.GetQubit(operand) if !ok { return fmt.Errorf("operand=%s: %w", operand, ErrQubitNotFound) } diff --git a/visitor/visitor_test.go b/visitor/visitor_test.go index 1c01f13..18cc128 100644 --- a/visitor/visitor_test.go +++ b/visitor/visitor_test.go @@ -777,6 +777,7 @@ func TestVisitor_VisitMeasureArrowAssignmentStatement(t *testing.T) { case error: panic(ret) } + if len(c.want.classicalBit) != 0 { var found bool for _, w := range c.want.classicalBit {