diff --git a/planqk/qiskit/client/client.py b/planqk/qiskit/client/client.py index e91ad55..cc06424 100644 --- a/planqk/qiskit/client/client.py +++ b/planqk/qiskit/client/client.py @@ -46,6 +46,10 @@ def set_credentials(cls, credentials: DefaultCredentialsProvider): def get_credentials(cls): return cls._credentials + @classmethod + def set_organization_id(cls, organization_id: str): + cls._organization_id = organization_id + @classmethod def perform_request(cls, request_func: Callable[..., Response], url: str, params=None, data=None, headers=None): headers = {**cls._get_default_headers(), **(headers or {})} @@ -140,7 +144,10 @@ def _get_default_headers(cls): cls._context_resolver = ContextResolver() context = cls._context_resolver.get_context() - if context is not None and context.is_organization: + + if cls._organization_id is not None: + headers["x-organizationid"] = cls._organization_id + elif context is not None and context.is_organization: headers["x-organizationid"] = context.get_organization_id() headers[HEADER_CLOUD_TRACE_CTX] = cls._generate_trace_id() diff --git a/planqk/qiskit/provider.py b/planqk/qiskit/provider.py index fe12cd4..c1029fc 100644 --- a/planqk/qiskit/provider.py +++ b/planqk/qiskit/provider.py @@ -12,7 +12,7 @@ class PlanqkQuantumProvider(Provider): - def __init__(self, access_token: str = None): + def __init__(self, access_token: str = None, organization_id: str = None): """Initialize the PlanQK provider. Args: access_token (str): access token used for authentication with PlanQK. If not token is provided, @@ -20,6 +20,7 @@ def __init__(self, access_token: str = None): manually or by using the PlanQK CLI. """ _PlanqkClient.set_credentials(DefaultCredentialsProvider(access_token)) + _PlanqkClient.set_organization_id(organization_id) def backends(self, provider: PROVIDER = None, **kwargs): """ diff --git a/planqk/qiskit/runtime_provider.py b/planqk/qiskit/runtime_provider.py index 5d4fa4b..571da43 100644 --- a/planqk/qiskit/runtime_provider.py +++ b/planqk/qiskit/runtime_provider.py @@ -19,8 +19,11 @@ class PlanqkQiskitRuntimeService(PlanqkQuantumProvider): - def __init__(self, access_token=None, channel: Optional[ChannelType] = None, channel_strategy=None): - super().__init__(access_token) + def __init__(self, access_token: Optional[str] = None, + organization_id: Optional[str] = None, + channel: Optional[ChannelType] = None, + channel_strategy=None): + super().__init__(access_token, organization_id) self._channel = channel self._channel_strategy = channel_strategy @@ -73,7 +76,7 @@ def run(self, qrt_options.validate(channel=self.channel) hgp_name = 'ibm-q/open/main' - + runtime_job_params = RuntimeJobParamsDto( program_id=program_id, image=qrt_options.image, diff --git a/tests/integration/test_context.py b/tests/integration/test_context.py index edaa607..186591e 100644 --- a/tests/integration/test_context.py +++ b/tests/integration/test_context.py @@ -3,6 +3,23 @@ import unittest.mock from planqk.context import ContextResolver +from planqk.qiskit import PlanqkQuantumProvider +from planqk.qiskit.client.client import _PlanqkClient + + +def _create_context_env_file(): + json_value = """ + { + "context": { + "id": "c557000f-f2b1-4505-8172-dac7960caf16", + "displayName": "Test Org", + "isOrganization": true + } + } + """ + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as fp: + fp.write(json_value.encode("utf-8")) + os.environ["PLANQK_CONFIG_FILE_PATH"] = os.path.abspath(fp.name) class ContextResolverTestSuite(unittest.TestCase): @@ -37,18 +54,7 @@ def test_should_get_organization_id_from_context_when_env_var_set(self): self.assertEqual(context.get_organization_id(), "c557000f-f2b1-4505-8172-dac7960caf15") def test_should_get_organization_id_from_context(self): - json_value = """ - { - "context": { - "id": "c557000f-f2b1-4505-8172-dac7960caf16", - "displayName": "Test Org", - "isOrganization": true - } - } - """ - with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as fp: - fp.write(json_value.encode("utf-8")) - os.environ["PLANQK_CONFIG_FILE_PATH"] = os.path.abspath(fp.name) + _create_context_env_file() context_resolver = ContextResolver() context = context_resolver.get_context() @@ -98,3 +104,10 @@ def test_should_return_none_when_file_is_empty(self): context = context_resolver.get_context() self.assertIsNone(context) + + def test_should_use_user_provided_org_id(self): + _create_context_env_file() + access_token = "user_access_token" + user_org_id = "user_org_id" + PlanqkQuantumProvider(access_token, user_org_id) + self.assertEqual(_PlanqkClient._get_default_headers()["x-organizationid"], user_org_id)