diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 635648ed..07c6c4be 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -675,7 +675,7 @@ def _submit_request_and_process( Tuple with data and total number of docs in matching the query in the database. """ try: - response = self.session.get(url=url, verify=verify, params=params, timeout=timeout) + response = self.session.get(url=url, verify=verify, params=params, timeout=timeout, headers=self.headers) except requests.exceptions.ConnectTimeout: raise MPRestError(f"REST query timed out on URL {url}. Try again with a smaller request.") diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index a8427ff0..8d70e827 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -17,7 +17,7 @@ from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry from pymatgen.io.vasp import Chgcar from pymatgen.symmetry.analyzer import SpacegroupAnalyzer -from requests import get +from requests import get, Session from typing import Literal from mp_api.client.core import BaseRester, MPRestError @@ -85,6 +85,7 @@ def __init__( include_user_agent=True, monty_decode: bool = True, use_document_model: bool = True, + session: Session = None, headers: dict = None, ): """ @@ -116,6 +117,7 @@ def __init__( use_document_model: If False, skip the creating the document model and return data as a dictionary. This can be simpler to work with but bypasses data validation and will not give auto-complete for available fields. + session (Session): Session object to use. By default (None), the client will create one. headers (dict): Custom headers for localhost connections. """ @@ -129,7 +131,7 @@ def __init__( self.api_key = api_key self.endpoint = endpoint self.headers = headers or {} - self.session = BaseRester._create_session( + self.session = session or BaseRester._create_session( api_key=api_key, include_user_agent=include_user_agent, headers=self.headers ) self.use_document_model = use_document_model @@ -138,7 +140,7 @@ def __init__( try: from mpcontribs.client import Client - self.contribs = Client(api_key, headers=self.headers) + self.contribs = Client(api_key, headers=self.headers, sesson=self.session) except ImportError: self.contribs = None warnings.warn( @@ -167,6 +169,7 @@ def __init__( session=self.session, monty_decode=monty_decode, use_document_model=use_document_model, + headers=self.headers ) # type: BaseRester self._all_resters.append(rester)