Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial adaptations for Server API version 0.3.0 #1106

Draft
wants to merge 23 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
407 changes: 406 additions & 1 deletion cirq-superstaq/cirq_superstaq/job.py

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions cirq-superstaq/cirq_superstaq/job_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ def job() -> css.Job:
Returns:
A `cirq_superstaq` Job instance.
"""
client = gss.superstaq_client._SuperstaqClient(
client = gss.superstaq_client.get_client(
client_name="cirq-superstaq",
remote_host="http://example.com",
api_key="to_my_heart",
)
assert isinstance(client, gss.superstaq_client._SuperstaqClient)
return css.Job(client, "job_id")


Expand Down Expand Up @@ -76,11 +77,12 @@ def new_job() -> css.Job:
Returns:
A `cirq_superstaq` Job instance.
"""
client = gss.superstaq_client._SuperstaqClient(
client = gss.superstaq_client.get_client(
client_name="cirq-superstaq",
remote_host="http://example.com",
api_key="to_my_heart",
)
assert isinstance(client, gss.superstaq_client._SuperstaqClient)
return css.Job(client, "new_job_id")


Expand All @@ -91,11 +93,12 @@ def multi_circuit_job() -> css.Job:
Returns:
A job with multiple subjobs
"""
client = gss.superstaq_client._SuperstaqClient(
client = gss.superstaq_client.get_client(
client_name="cirq-superstaq",
remote_host="http://example.com",
api_key="to_my_heart",
)
assert isinstance(client, gss.superstaq_client._SuperstaqClient)
job = css.Job(client, "job_id1,job_id2,job_id3")
job._job = {
"job_id1": {"status": "Done"},
Expand Down
13 changes: 9 additions & 4 deletions cirq-superstaq/cirq_superstaq/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def __init__(
EnvironmentError: If an API key was not provided and could not be found.
"""
self.default_target = default_target
self._client = superstaq_client._SuperstaqClient(
self._client = superstaq_client.get_client(
client_name="cirq-superstaq",
remote_host=remote_host,
api_key=api_key,
Expand Down Expand Up @@ -354,9 +354,12 @@ def create_job(
method=method,
**kwargs,
)
# Make a virtual job_id that aggregates all of the individual jobs
# into a single one that comma-separates the individual jobs.
job_id = ",".join(result["job_ids"])
if isinstance(self._client, gss.superstaq_client._SuperstaqClient_v0_3_0):
job_id = result["job_id"]
else:
# Make a virtual job_id that aggregates all of the individual jobs
# into a single one that comma-separates the individual jobs.
job_id = ",".join(result["job_ids"])

# The returned job does not have fully populated fields; they will be filled out by
# when the new job's status is first queried
Expand All @@ -375,6 +378,8 @@ def get_job(self, job_id: str) -> css.job.Job:
Raises:
~gss.SuperstaqServerException: If there was an error accessing the API.
"""
if isinstance(self._client, gss.superstaq_client._SuperstaqClient_v0_3_0):
return css.job._Job(client=self._client, job_id=job_id)
return css.job.Job(client=self._client, job_id=job_id)

def resource_estimate(
Expand Down
50 changes: 50 additions & 0 deletions cirq-superstaq/cirq_superstaq/service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,25 @@ def test_service_get_job() -> None:
mock_client.fetch_jobs.assert_called_once_with(["job_id"])


def test_service_get_job_v0_3_0() -> None:
service = css.Service(api_key="key", remote_host="http://example.com", api_version="v0.3.0")
mock_client = mock.MagicMock(
spec=gss.superstaq_client._SuperstaqClient_v0_3_0, api_version="v0.3.0"
)
job_dict = mock.MagicMock(
spec=gss._models.JobData,
statuses=[gss._models.CircuitStatus.RUNNING, gss._models.CircuitStatus.AWAITING_SUBMISSION],
)
mock_client.fetch_single_job.return_value = job_dict
service._client = mock_client

job = service.get_job("job_id")

# fetch_jobs() should not be called upon construction
assert job.job_id() == "job_id"
assert job._job_data == job_dict # type: ignore[attr-defined]


def test_service_create_job() -> None:
service = css.Service(api_key="key", remote_host="http://example.com")
mock_client = mock.MagicMock()
Expand All @@ -247,6 +266,37 @@ def test_service_create_job() -> None:
assert create_job_kwargs["fake_data"] == ""


def test_service_create_job_v0_3_0() -> None:
service = css.Service(api_key="key", remote_host="http://example.com", api_version="v0.3.0")
mock_client = mock.MagicMock(
api_version="v0.3.0", spec=gss.superstaq_client._SuperstaqClient_v0_3_0
)
mock_client.create_job.return_value = mock.MagicMock(
job_id="job_id", statuses=[gss._models.CircuitStatus.RECEIVED]
)

mock_client.fetch_single_job.return_value = mock.MagicMock(
job_id="job_id", statuses=[gss._models.CircuitStatus.COMPLETED]
)
service._client = mock_client

circuit = cirq.Circuit(cirq.X(cirq.LineQubit(0)), cirq.measure(cirq.LineQubit(0)))
job = service.create_job(
circuits=circuit,
repetitions=100,
target="ss_fake_qpu",
method="fake_method",
fake_data="",
)
assert job.status() == "completed"
create_job_kwargs = mock_client.create_job.call_args[1]
# Serialization induces a float, so we don't validate full circuit.
assert create_job_kwargs["repetitions"] == 100
assert create_job_kwargs["target"] == "ss_fake_qpu"
assert create_job_kwargs["method"] == "fake_method"
assert create_job_kwargs["fake_data"] == ""


@mock.patch(
"general_superstaq.superstaq_client._SuperstaqClient.post_request",
return_value={
Expand Down
11 changes: 10 additions & 1 deletion general-superstaq/general_superstaq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,15 @@
)
from general_superstaq.typing import Target

from . import serialization, service, superstaq_client, superstaq_exceptions, typing, validation
from . import (
_models,
serialization,
service,
superstaq_client,
superstaq_exceptions,
typing,
validation,
)

__all__ = [
"__version__",
Expand All @@ -27,4 +35,5 @@
"typing",
"Target",
"validation",
"_models",
]
26 changes: 24 additions & 2 deletions general-superstaq/general_superstaq/_models.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
"""Data models used when communicating with the superstaq server."""

# pragma: no cover
from __future__ import annotations

import datetime
import uuid
from collections.abc import Sequence
from enum import Enum
from typing import Any
from typing import TYPE_CHECKING, Any

import pydantic

if TYPE_CHECKING:
from typing_extensions import Self


class JobType(str, Enum):
"""The different types of jobs that can be submitted through Superstaq."""
Expand Down Expand Up @@ -151,6 +153,26 @@ class JobData(DefaultPydanticModel):
final_logical_to_physicals: list[str | None]
"""Serialized initial final-to-physical mapping for each circuit."""

@pydantic.model_validator(mode="after")
def validate_consistent_number_of_circuits(self) -> Self:
"""Checks that all lists contain the correct number of elements
(equal to the number of circuits).

Raises:
ValueError: If any list attribute has the wrong number of elements

Returns:
The validated model.
"""
for name, attr in self.model_dump().items():
if isinstance(attr, list):
if len(attr) != self.num_circuits:
raise ValueError(
f"Field {name} does not contain the correct number of elements. "
f"Expected {self.num_circuits} but found {len(attr)}."
)
return self


class NewJob(DefaultPydanticModel):
"""The data model for submitting new jobs."""
Expand Down
44 changes: 44 additions & 0 deletions general-superstaq/general_superstaq/_models_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# pylint: disable=missing-function-docstring,missing-class-docstring,missing-return-doc,
# pylint: disable=unused-argument
from __future__ import annotations

import datetime

import pydantic
import pytest

import general_superstaq as gss


def test_job_data_list_validation() -> None:
with pytest.raises(
pydantic.ValidationError,
match=(
"Field input_circuits does not contain the correct number of elements. "
"Expected 1 but found 2."
),
):
gss._models.JobData(
job_type=gss._models.JobType.DEVICE_SUBMISSION,
statuses=[gss._models.CircuitStatus.RUNNING],
status_messages=[None],
user_email="[email protected]",
target="ss_example_qpu",
provider_id=["example_id"],
num_circuits=1,
compiled_circuit_type=gss._models.CircuitType.CIRQ,
compiled_circuits=["compiled_circuit"],
input_circuits=["input_circuits", "extra_circuit!"],
input_circuit_type=gss._models.CircuitType.QISKIT,
pulse_gate_circuits=[None],
counts=[None],
state_vectors=[None],
results_dicts=[None],
num_qubits=[3],
shots=[100],
dry_run=False,
submission_timestamp=datetime.datetime(2000, 1, 1, 0, 0, 0),
last_updated_timestamp=[datetime.datetime(2000, 1, 1, 0, 1, 0)],
initial_logical_to_physicals=[None],
final_logical_to_physicals=[None],
)
2 changes: 1 addition & 1 deletion general-superstaq/general_superstaq/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(
client: The Superstaq client to use.
"""

self._client = gss.superstaq_client._SuperstaqClient(
self._client = gss.superstaq_client.get_client(
client_name="general-superstaq",
remote_host=remote_host,
api_key=api_key,
Expand Down
Loading
Loading