Skip to content

Commit

Permalink
Merge pull request #415 from FZJ-INM1-BDA/enableRetrievalTests
Browse files Browse the repository at this point in the history
HttpRequest.find_suitable_decoder returns None if no suitable decoder found
  • Loading branch information
AhmetNSimsek authored Jul 11, 2023
2 parents 8066bf9 + 0ff2652 commit dfce6ec
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 61 deletions.
4 changes: 2 additions & 2 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ addopts = -rf
testpaths =
# slowly add tests, until all tests are added
test/test_siibra.py

test/retrieval/
test/core/
test/volumes/

# eventually, only use the below ini
# test/
152 changes: 98 additions & 54 deletions siibra/retrieval/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@

from .cache import CACHE
from .exceptions import EbrainsAuthenticationError
from ..commons import logger, HBP_AUTH_TOKEN, KEYCLOAK_CLIENT_ID, KEYCLOAK_CLIENT_SECRET, siibra_tqdm, SIIBRA_USE_LOCAL_SNAPSPOT
from ..commons import (
logger,
HBP_AUTH_TOKEN,
KEYCLOAK_CLIENT_ID,
KEYCLOAK_CLIENT_SECRET,
siibra_tqdm,
SIIBRA_USE_LOCAL_SNAPSPOT,
)
from .. import __version__

import json
Expand Down Expand Up @@ -56,22 +63,19 @@
".txt": lambda b: pd.read_csv(BytesIO(b), delimiter=" ", header=None),
".zip": lambda b: ZipFile(BytesIO(b)),
".png": lambda b: skimage_io.imread(BytesIO(b)),
".npy": lambda b: np.load(BytesIO(b))
".npy": lambda b: np.load(BytesIO(b)),
}


class SiibraHttpRequestError(Exception):
    """Raised when an HTTP request issued by siibra fails (non-OK status)."""

    def __init__(self, url: str, status_code: int, msg="Cannot execute http request."):
        # Keep the failing request's details so __str__ can report them.
        self.url = url
        self.status_code = status_code
        self.msg = msg
        super().__init__()

    def __str__(self):
        # "{:76.76}" pads and truncates the url to exactly 76 characters.
        return "{}\n\tStatus code: {}\n\tUrl: {:76.76}".format(
            self.msg, self.status_code, self.url
        )


class HttpRequest:
Expand Down Expand Up @@ -117,14 +121,18 @@ def __init__(
@staticmethod
def find_suitiable_decoder(url: str):
    """Return a decoder callable for the file suffix of `url`'s path,
    or None if no suitable decoder is known.

    For gzipped payloads (".gz" suffix), the returned callable first
    decompresses the bytes and then applies the decoder matching the
    inner suffix (e.g. ".nii.gz" -> ".nii"), if one exists.
    """
    urlpath = urllib.parse.urlsplit(url).path
    if urlpath.endswith(".gz"):
        # Decoder for the suffix underneath ".gz", if any.
        dec = HttpRequest.find_suitiable_decoder(urlpath[:-3])
        if dec is None:
            # BUGFIX: previously a lambda closing over dec=None was
            # returned, which raised TypeError only when the decoder was
            # eventually called. At least decompress the payload.
            return lambda b: gzip.decompress(b)
        return lambda b: dec(gzip.decompress(b))

    suitable_decoders = [
        dec for sfx, dec in DECODERS.items() if urlpath.endswith(sfx)
    ]
    if len(suitable_decoders) > 0:
        # Suffixes registered in DECODERS are mutually exclusive, so at
        # most one decoder should ever match.
        assert len(suitable_decoders) == 1
        return suitable_decoders[0]
    return None

def _set_decoder_func(self, func):
self.func = func or self.find_suitiable_decoder(self.url)
Expand All @@ -146,24 +154,34 @@ def _retrieve(self, block_size=1024, min_bytesize_with_no_progress_info=2e8):
if self.msg_if_not_cached is not None:
logger.debug(self.msg_if_not_cached)

headers = self.kwargs.get('headers', {})
other_kwargs = {key: self.kwargs[key] for key in self.kwargs if key != "headers"}
headers = self.kwargs.get("headers", {})
other_kwargs = {
key: self.kwargs[key] for key in self.kwargs if key != "headers"
}

http_method = requests.post if self.post else requests.get
r = http_method(self.url, headers={
**USER_AGENT_HEADER,
**headers,
}, **other_kwargs, stream=True)
r = http_method(
self.url,
headers={
**USER_AGENT_HEADER,
**headers,
},
**other_kwargs,
stream=True,
)

if not r.ok:
raise SiibraHttpRequestError(status_code=r.status_code, url=self.url)

size_bytes = int(r.headers.get('content-length', 0))
size_bytes = int(r.headers.get("content-length", 0))
if size_bytes > min_bytesize_with_no_progress_info:
progress_bar = siibra_tqdm(
total=size_bytes, unit='iB', unit_scale=True,
position=0, leave=True,
desc=f"Downloading {os.path.split(self.url)[-1]} ({size_bytes / 1024**2:.1f} MiB)"
total=size_bytes,
unit="iB",
unit_scale=True,
position=0,
leave=True,
desc=f"Downloading {os.path.split(self.url)[-1]} ({size_bytes / 1024**2:.1f} MiB)",
)
temp_cachefile = f"{self.cachefile}_temp"
lock = Lock(f"{temp_cachefile}.lock")
Expand Down Expand Up @@ -204,8 +222,8 @@ def data(self):


class ZipfileRequest(HttpRequest):
def __init__(self, url, filename, func=None, refresh=False):
    """Request a zip archive from `url` and decode the member `filename`.

    The decoder is chosen from the archive member's suffix rather than
    the url's (which would end in ".zip").
    """
    super().__init__(url, func=func, refresh=refresh)
    self.filename = filename
    # Replace the decoder derived from the url with one suited to the
    # requested archive member.
    self._set_decoder_func(self.find_suitiable_decoder(self.filename))

Expand Down Expand Up @@ -259,16 +277,24 @@ def init_oidc(cls):
resp = requests.get(f"{cls._IAM_ENDPOINT}/.well-known/openid-configuration")
json_resp = resp.json()
if "token_endpoint" in json_resp:
logger.debug(f"token_endpoint exists in .well-known/openid-configuration. Setting _IAM_TOKEN_ENDPOINT to {json_resp.get('token_endpoint')}")
logger.debug(
f"token_endpoint exists in .well-known/openid-configuration. Setting _IAM_TOKEN_ENDPOINT to {json_resp.get('token_endpoint')}"
)
cls._IAM_TOKEN_ENDPOINT = json_resp.get("token_endpoint")
else:
logger.warning("expect token endpoint in .well-known/openid-configuration, but was not present")
logger.warning(
"expect token endpoint in .well-known/openid-configuration, but was not present"
)

if "device_authorization_endpoint" in json_resp:
logger.debug(f"device_authorization_endpoint exists in .well-known/openid-configuration. setting _IAM_DEVICE_ENDPOINT to {json_resp.get('device_authorization_endpoint')}")
logger.debug(
f"device_authorization_endpoint exists in .well-known/openid-configuration. setting _IAM_DEVICE_ENDPOINT to {json_resp.get('device_authorization_endpoint')}"
)
cls._IAM_DEVICE_ENDPOINT = json_resp.get("device_authorization_endpoint")
else:
logger.warning("expected device_authorization_endpoint in .well-known/openid-configuration, but was not present")
logger.warning(
"expected device_authorization_endpoint in .well-known/openid-configuration, but was not present"
)

@classmethod
def fetch_token(cls, **kwargs):
Expand All @@ -283,11 +309,17 @@ def fetch_token(cls, **kwargs):

@classmethod
def device_flow(cls, **kwargs):
if all([
not sys.__stdout__.isatty(), # if is tty, do not raise
not any(k in ['JPY_INTERRUPT_EVENT', "JPY_PARENT_PID"] for k in os.environ), # if is notebook environment, do not raise
not os.getenv("SIIBRA_ENABLE_DEVICE_FLOW"), # if explicitly enabled by env var, do not raise
]):
if all(
[
not sys.__stdout__.isatty(), # if is tty, do not raise
not any(
k in ["JPY_INTERRUPT_EVENT", "JPY_PARENT_PID"] for k in os.environ
), # if is notebook environment, do not raise
not os.getenv(
"SIIBRA_ENABLE_DEVICE_FLOW"
), # if explicitly enabled by env var, do not raise
]
):
raise EbrainsAuthenticationError(
"sys.stdout is not tty, SIIBRA_ENABLE_DEVICE_FLOW is not set,"
"and not running in a notebook. Are you running in batch mode?"
Expand All @@ -306,22 +338,17 @@ def get_scopes() -> str:
logger.warning("scopes needs to be all str, but is not")
return None
if len(scopes) == 0:
logger.warning('provided empty list as scopes... skipping')
logger.warning("provided empty list as scopes... skipping")
return None
return "+".join(scopes)

scopes = get_scopes()

data = {
'client_id': cls._IAM_DEVICE_FLOW_CLIENTID
}
data = {"client_id": cls._IAM_DEVICE_FLOW_CLIENTID}

if scopes:
data['scopes'] = scopes
resp = requests.post(
url=cls._IAM_DEVICE_ENDPOINT,
data=data
)
data["scopes"] = scopes
resp = requests.post(url=cls._IAM_DEVICE_ENDPOINT, data=data)
resp.raise_for_status()
resp_json = resp.json()
logger.debug("device flow, request full json:", resp_json)
Expand All @@ -344,17 +371,19 @@ def get_scopes() -> str:

logger.debug("Calling endpoint")
if attempt_number > cls._IAM_DEVICE_MAXTRIES:
message = f"exceeded max attempts: {cls._IAM_DEVICE_MAXTRIES}, aborting..."
message = (
f"exceeded max attempts: {cls._IAM_DEVICE_MAXTRIES}, aborting..."
)
logger.error(message)
raise EbrainsAuthenticationError(message)
attempt_number += 1
resp = requests.post(
url=cls._IAM_TOKEN_ENDPOINT,
data={
'grant_type': "urn:ietf:params:oauth:grant-type:device_code",
'client_id': cls._IAM_DEVICE_FLOW_CLIENTID,
'device_code': device_code
}
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"client_id": cls._IAM_DEVICE_FLOW_CLIENTID,
"device_code": device_code,
},
)

if resp.status_code == 200:
Expand All @@ -381,7 +410,6 @@ def set_token(cls, token):

@property
def kg_token(self):

# token is available, return it
if self.__class__._KG_API_TOKEN is not None:
return self.__class__._KG_API_TOKEN
Expand Down Expand Up @@ -448,7 +476,7 @@ def get(self):
def try_all_connectors():
def outer(fn):
@wraps(fn)
def inner(self: 'GitlabProxyEnum', *args, **kwargs):
def inner(self: "GitlabProxyEnum", *args, **kwargs):
exceptions = []
for connector in self.connectors:
try:
Expand All @@ -459,7 +487,9 @@ def inner(self: 'GitlabProxyEnum', *args, **kwargs):
for exc in exceptions:
logger.error(exc)
raise Exception("try_all_connectors failed")

return inner

return outer


Expand All @@ -470,30 +500,40 @@ class GitlabProxyEnum(Enum):
DATASETVERSION_V3 = "DATASETVERSION_V3"

@property
def connectors(self) -> List["GitlabConnector"]:
    """Repositories to query, in order of preference.

    A configured local snapshot takes precedence; otherwise gitlab
    connectors for the known mirror servers are returned.
    """
    from .repositories import GitlabConnector, LocalFileRepository

    if SIIBRA_USE_LOCAL_SNAPSPOT:
        logger.info(f"Using localsnapshot at {SIIBRA_USE_LOCAL_SNAPSPOT}")
        return [LocalFileRepository(SIIBRA_USE_LOCAL_SNAPSPOT)]

    # (server url, gitlab project id) pairs, tried in this order.
    server_projects = [
        ("https://jugit.fz-juelich.de", 7846),
        ("https://gitlab.ebrains.eu", 421),
    ]
    return [
        GitlabConnector(server_url, project_id, "master", archive_mode=True)
        for server_url, project_id in server_projects
    ]

@try_all_connectors()
def search_files(
    self,
    folder: str,
    suffix=None,
    recursive=True,
    *,
    connector: "GitlabConnector" = None,
) -> List[str]:
    """List files under `folder` via the first responsive connector.

    `connector` is keyword-only and injected by the try_all_connectors
    decorator; callers do not pass it themselves.
    """
    assert connector
    return connector.search_files(folder, suffix=suffix, recursive=recursive)

@try_all_connectors()
def get(self, filename, decode_func=None, *, connector: "GitlabConnector" = None):
    """Fetch `filename` via the first responsive connector.

    `connector` is injected by the try_all_connectors decorator.
    """
    assert connector
    return connector.get(filename, "", decode_func)


class GitlabProxy(HttpRequest):

folder_dict = {
GitlabProxyEnum.DATASET_V1: "ebrainsquery/v1/dataset",
GitlabProxyEnum.DATASET_V3: "ebrainsquery/v3/Dataset",
Expand All @@ -505,11 +545,11 @@ def __init__(
self,
flavour: GitlabProxyEnum,
instance_id=None,
postprocess: Callable[['GitlabProxy', Any], Any] = (
postprocess: Callable[["GitlabProxy", Any], Any] = (
lambda proxy, obj: obj
if hasattr(proxy, "instance_id") and proxy.instance_id
else {"results": obj}
)
),
):
if flavour not in GitlabProxyEnum:
raise RuntimeError("Can only proxy enum members")
Expand All @@ -522,7 +562,9 @@ def __init__(

def get(self):
    """Fetch and postprocess either a single instance json (when
    instance_id is set) or the folder's "_all" listing."""
    if self.instance_id:
        target = f"{self.folder}/{self.instance_id}.json"
    else:
        target = f"{self.folder}/_all.json"
    return self.postprocess(self, self.flavour.get(target))


Expand All @@ -544,7 +586,9 @@ def get(self):
except Exception as e:
exceptions.append(e)
else:
raise MultiSourceRequestException("All requests failed:\n" + "\n".join(str(exc) for exc in exceptions))
raise MultiSourceRequestException(
"All requests failed:\n" + "\n".join(str(exc) for exc in exceptions)
)

@property
def data(self):
Expand Down
8 changes: 3 additions & 5 deletions test/retrieval/test_retrieval_download_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@

class TestretrievalDownloadFile(unittest.TestCase):
    """Integration test: download a zip from an EBRAINS bucket and decode
    one of its members via ZipfileRequest."""

    def test_download_zipped_file(self):
        url = "https://data-proxy.ebrains.eu/api/v1/buckets/d-37258979-8b9f-4817-9d83-f009019a6c38/Semi-quantitative-analysis-siibra-csv.zip"
        member = "F9-BDA.csv"
        # refresh=True bypasses any cached copy so the retrieval path is
        # actually exercised.
        loader = siibra.retrieval.requests.ZipfileRequest(url, member, refresh=True)
        self.assertIsNotNone(loader.data)

# TODO Clear cache folder after test (for local testing)


if __name__ == "__main__":
unittest.main()

0 comments on commit dfce6ec

Please sign in to comment.