Skip to content

Commit

Permalink
Specify cache key via protocol (#80)
Browse files Browse the repository at this point in the history
* Specify cache key via protocol

* formatting, v0.3.4
  • Loading branch information
PhilReinhold authored Oct 30, 2023
1 parent 25c3033 commit 43d4456
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 7 deletions.
16 changes: 12 additions & 4 deletions oqpy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
from __future__ import annotations

import math
import uuid
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
Hashable,
Iterable,
Optional,
Protocol,
Expand Down Expand Up @@ -351,6 +353,8 @@ class CachedExpressionConvertible(Protocol):
no guarantees are made about this.
"""

_oqpy_cache_key: Hashable

def _to_cached_oqpy_expression(self) -> HasToAst:
... # pragma: no cover

Expand Down Expand Up @@ -469,10 +473,14 @@ def to_ast(program: Program, item: AstConvertible) -> ast.Expression:
item = cast(ExpressionConvertible, item)
return item._to_oqpy_expression().to_ast(program)
if hasattr(item, "_to_cached_oqpy_expression"):
if id(item) not in program.expr_cache:
item = cast(CachedExpressionConvertible, item)
program.expr_cache[id(item)] = item._to_cached_oqpy_expression().to_ast(program)
return program.expr_cache[id(item)]
item = cast(CachedExpressionConvertible, item)
if item._oqpy_cache_key is None:
item._oqpy_cache_key = uuid.uuid1()
if item._oqpy_cache_key not in program.expr_cache:
program.expr_cache[item._oqpy_cache_key] = item._to_cached_oqpy_expression().to_ast(
program
)
return program.expr_cache[item._oqpy_cache_key]
if isinstance(item, (complex, np.complexfloating)):
if item.imag == 0:
return to_ast(program, item.real)
Expand Down
4 changes: 2 additions & 2 deletions oqpy/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import warnings
from copy import deepcopy
from typing import Any, Iterable, Iterator, Optional
from typing import Any, Hashable, Iterable, Iterator, Optional

from openpulse import ast
from openpulse.printer import dumps
Expand Down Expand Up @@ -107,7 +107,7 @@ def __init__(self, version: Optional[str] = "3.0", simplify_constants: bool = Tr
self.simplify_constants = simplify_constants
self.declared_subroutines: set[str] = set()
self.declared_gates: set[str] = set()
self.expr_cache: dict[int, ast.Expression] = {}
self.expr_cache: dict[Hashable, ast.Expression] = {}
"""A cache of ast made by CachedExpressionConvertible objects used in this program.
This is used by `to_ast` to avoid repetitively evaluating ast conversion methods.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "oqpy"
version = "0.3.3"
version = "0.3.4"
description = "Generating OpenQASM 3 + OpenPulse in Python"
authors = ["OQpy Contributors <[email protected]>"]
license = "Apache-2.0"
Expand Down
2 changes: 2 additions & 0 deletions tests/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1575,6 +1575,7 @@ def test_cached_expression_convertible():
class A:
name: str
count: int = 0
_oqpy_cache_key = None

def _to_cached_oqpy_expression(self):
self.count += 1
Expand All @@ -1584,6 +1585,7 @@ def _to_cached_oqpy_expression(self):
class F:
name: str
count: int = 0
_oqpy_cache_key = None

def _to_cached_oqpy_expression(self):
self.count += 1
Expand Down

0 comments on commit 43d4456

Please sign in to comment.