Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Escape MAST Download URIs #3080

Merged
merged 3 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions astroquery/mast/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import warnings
import time
import os
from urllib.parse import quote

import numpy as np

Expand Down Expand Up @@ -534,6 +535,7 @@ def download_file(self, uri, *, local_path=None, base_url=None, cache=True, clou
# create the full data URL
base_url = base_url if base_url else self._portal_api_connection.MAST_DOWNLOAD_URL
data_url = base_url + "?uri=" + uri
escaped_url = base_url + "?uri=" + quote(uri, safe=":/")

# parse a local file path from local_path parameter. Use current directory as default.
filename = os.path.basename(uri)
Expand Down Expand Up @@ -565,11 +567,11 @@ def download_file(self, uri, *, local_path=None, base_url=None, cache=True, clou
status = "SKIPPED"
else:
log.warning("Falling back to mast download...")
self._download_file(data_url, local_path,
self._download_file(escaped_url, local_path,
cache=cache, head_safe=True, continuation=False,
verbose=verbose)
else:
self._download_file(data_url, local_path,
self._download_file(escaped_url, local_path,
cache=cache, head_safe=True, continuation=False,
verbose=verbose)

Expand Down
51 changes: 13 additions & 38 deletions astroquery/mast/tests/test_mast_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,6 @@ def test_mast_service_request_async(self):
assert isinstance(responses, list)

def test_mast_service_request(self):

# clear columns config
Mast._column_configs = dict()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this and the other assignments to dict() were done to reset the class attribute _column_configs. Are we sure we don't need this anymore?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not 100% sure what purpose the lines serve, but removing them has no impact on the test results. From what I can tell, _column_configs seems to be a sort of caching variable to avoid having to make multiple calls for the columnsConfig entry for a service. I don't see any apparent reason that this should be reset for certain tests, since I don't imagine that we expect the configuration for a service to change all that often. What do you think?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with removing them from the tests.


service = 'Mast.Caom.Cone'
params = {'ra': 184.3,
'dec': 54.5,
Expand All @@ -174,9 +170,6 @@ def test_mast_service_request(self):
assert len(result[np.where(result["obs_id"] == "6374399093149532160")]) == 2

def test_mast_query(self):
# clear columns config
Mast._column_configs = dict()

result = Mast.mast_query('Mast.Caom.Cone', ra=184.3, dec=54.5, radius=0.2)

# Is result in the right format
Expand Down Expand Up @@ -225,9 +218,6 @@ def test_observations_query_region_async(self):
assert isinstance(responses, list)

def test_observations_query_region(self):
# clear columns config
Observations._column_configs = dict()

result = Observations.query_region("322.49324 12.16683", radius="0.005 deg")
assert isinstance(result, Table)
assert len(result) > 500
Expand All @@ -243,9 +233,6 @@ def test_observations_query_object_async(self):
assert isinstance(responses, list)

def test_observations_query_object(self):
# clear columns config
Observations._column_configs = dict()

result = Observations.query_object("M8", radius=".04 deg")
assert isinstance(result, Table)
assert len(result) > 150
Expand All @@ -264,10 +251,6 @@ def test_observations_query_criteria_async(self):
assert isinstance(responses, list)

def test_observations_query_criteria(self):

# clear columns config
Observations._column_configs = dict()

# without position
result = Observations.query_criteria(instrument_name="*WFPC2*",
proposal_id=8169,
Expand Down Expand Up @@ -333,10 +316,6 @@ def test_observations_get_product_list_async(self):
assert isinstance(responses, list)

def test_observations_get_product_list(self):

# clear columns config
Observations._column_configs = dict()

observations = Observations.query_object("M8", radius=".04 deg")
test_obs_id = str(observations[0]['obsid'])
mult_obs_ids = str(observations[0]['obsid']) + ',' + str(observations[1]['obsid'])
Expand Down Expand Up @@ -519,6 +498,19 @@ def test_observations_download_file_cloud(self, tmp_path, in_uri):
assert result == ('COMPLETE', None, None)
assert Path(tmp_path, filename).exists()

def test_observations_download_file_escaped(self, tmp_path):
    """Test that `download_file` correctly escapes a URI.

    The URI below carries a query string ('?', '=' and '&' characters).
    The test downloads it into ``tmp_path``, checks the reported status,
    and confirms the resulting file can be opened as FITS.
    """
    # Parenthesized literal instead of a backslash continuation.
    in_uri = ('mast:HLA/url/cgi-bin/fitscut.cgi?'
              'red=hst_04819_65_wfpc2_f814w_pc&blue=hst_04819_65_wfpc2_f555w_pc&size=ALL&format=fits')
    # Build the expected local path once and reuse it below.
    downloaded = Path(tmp_path, Path(in_uri).name)

    result = Observations.download_file(uri=in_uri, local_path=tmp_path)
    assert result == ('COMPLETE', None, None)
    assert downloaded.exists()

    # check that downloaded file is a valid FITS file; the context manager
    # closes the file handle for us
    with fits.open(downloaded):
        pass

@pytest.mark.parametrize("test_data_uri, expected_cloud_uri", [
("mast:HST/product/u24r0102t_c1f.fits",
"s3://stpubdata/hst/public/u24r/u24r0102t/u24r0102t_c1f.fits"),
Expand Down Expand Up @@ -618,10 +610,7 @@ def check_result(result, row, exp_values):
for k, v in exp_values.items():
assert result[row][k] == v

# clear columns config
Catalogs._column_configs = dict()
in_radius = 0.1 * u.deg

result = Catalogs.query_region("158.47924 -7.30962",
radius=in_radius,
catalog="Gaia")
Expand Down Expand Up @@ -717,9 +706,6 @@ def check_result(result, exp_values):
for k, v in exp_values.items():
assert v in result[k]

# clear columns config
Catalogs._column_configs = dict()

result = Catalogs.query_object("M10",
radius=.001,
catalog="TIC")
Expand Down Expand Up @@ -819,9 +805,6 @@ def check_result(result, exp_vals):
for k, v in exp_vals.items():
assert v in result[k]

# clear columns config
Catalogs._column_configs = dict()

# without position
result = Catalogs.query_criteria(catalog="Tic",
Bmag=[30, 50],
Expand Down Expand Up @@ -897,10 +880,6 @@ def test_catalogs_query_hsc_matchid_async(self):
assert isinstance(responses, list)

def test_catalogs_query_hsc_matchid(self):

# clear columns config
Catalogs._column_configs = dict()

catalogData = Catalogs.query_object("M10",
radius=.001,
catalog="HSC",
Expand All @@ -921,10 +900,6 @@ def test_catalogs_get_hsc_spectra_async(self):
assert isinstance(responses, list)

def test_catalogs_get_hsc_spectra(self):

# clear columns config
Catalogs._column_configs = dict()

result = Catalogs.get_hsc_spectra()
assert isinstance(result, Table)
assert result[np.where(result['MatchID'] == '19657846')]
Expand Down
Loading