Skip to content

Commit

Permalink
Custom sessions (#716)
Browse files Browse the repository at this point in the history
* Add ability to pass custom session to MPRester

* Add custom headers to each request
  • Loading branch information
Jason Munro authored Dec 8, 2022
1 parent 6bf1791 commit ac4bb36
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ def _submit_request_and_process(
Tuple with data and total number of docs in matching the query in the database.
"""
try:
response = self.session.get(url=url, verify=verify, params=params, timeout=timeout)
response = self.session.get(url=url, verify=verify, params=params, timeout=timeout, headers=self.headers)
except requests.exceptions.ConnectTimeout:
raise MPRestError(f"REST query timed out on URL {url}. Try again with a smaller request.")

Expand Down
9 changes: 6 additions & 3 deletions mp_api/client/mprester.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry
from pymatgen.io.vasp import Chgcar
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from requests import get
from requests import get, Session
from typing import Literal

from mp_api.client.core import BaseRester, MPRestError
Expand Down Expand Up @@ -85,6 +85,7 @@ def __init__(
include_user_agent=True,
monty_decode: bool = True,
use_document_model: bool = True,
session: Session = None,
headers: dict = None,
):
"""
Expand Down Expand Up @@ -116,6 +117,7 @@ def __init__(
use_document_model: If False, skip the creating the document model and return data
as a dictionary. This can be simpler to work with but bypasses data validation
and will not give auto-complete for available fields.
session (Session): Session object to use. By default (None), the client will create one.
headers (dict): Custom headers for localhost connections.
"""

Expand All @@ -129,7 +131,7 @@ def __init__(
self.api_key = api_key
self.endpoint = endpoint
self.headers = headers or {}
self.session = BaseRester._create_session(
self.session = session or BaseRester._create_session(
api_key=api_key, include_user_agent=include_user_agent, headers=self.headers
)
self.use_document_model = use_document_model
Expand All @@ -138,7 +140,7 @@ def __init__(
try:
from mpcontribs.client import Client

self.contribs = Client(api_key, headers=self.headers)
self.contribs = Client(api_key, headers=self.headers, sesson=self.session)
except ImportError:
self.contribs = None
warnings.warn(
Expand Down Expand Up @@ -167,6 +169,7 @@ def __init__(
session=self.session,
monty_decode=monty_decode,
use_document_model=use_document_model,
headers=self.headers
) # type: BaseRester

self._all_resters.append(rester)
Expand Down

0 comments on commit ac4bb36

Please sign in to comment.