Skip to content

Commit

Permalink
Make mypy happy
Browse files Browse the repository at this point in the history
  • Loading branch information
ntellis committed Jan 26, 2024
1 parent 5347e70 commit 12b8650
Showing 1 changed file with 33 additions and 30 deletions.
63 changes: 33 additions & 30 deletions cutouts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import pathlib
import sys
from typing import Any, Dict, Iterable, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -34,7 +34,7 @@ def get_cutouts(
download_full_image: bool = False,
compare: bool = False,
compare_kwargs: Optional[dict] = None,
) -> Tuple[Iterable[Dict[str, Any]], Iterable[Dict[str, Any]]]:
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
""" """

# Get urls and metadata for each cutout
Expand All @@ -47,35 +47,38 @@ def get_cutouts(
"same_filter": True,
}

results = []
comparison_results = []
results: List[Dict[str, Any]] = []
comparison_results: List[Dict[str, Any]] = []
for record in cutout_requests.to_dict(orient="records"):
try:
cutout_request = CutoutRequest(**record) # type: ignore
results_df = find_cutouts(cutout_request)
result = select_cutout(results_df, cutout_request)
result = dict(select_cutout(results_df, cutout_request))

except Exception as e:
logger.warning(e)
result = {"error": e}

if compare:
try:
comparison_result = select_comparison_cutout(
results_df, result, cutout_request, **compare_kwargs
comparison_result = dict(
select_comparison_cutout(
results_df,
CutoutRequest(result),
cutout_request,
**compare_kwargs,
)
)

except Exception as e:
logger.warning(e)
comparison_result = {"error": e}

comparison_result = dict(comparison_result)
comparison_results.append(comparison_result)

else:
comparison_results.append(None)
comparison_results.append({"error": "Comparison cutout not requested"})

result = dict(result)
results.append(result)

for result in results:
Expand Down Expand Up @@ -128,19 +131,19 @@ def get_cutouts(

# Download comparison cutouts
if compare:
for result in comparison_results:
if result is not None:
if "error" in result:
for comp_result in comparison_results:
if comp_result is not None:
if "error" in comp_result:
continue

(
comparison_full_image_path,
comparison_cutout_image_path,
) = generate_local_image_paths(result)
result["cutout_image_path"] = (
comp_result["cutout_image_path"] = (
pathlib.Path(out_dir) / "comparison" / comparison_cutout_image_path
)
path = result["cutout_image_path"]
path = comp_result["cutout_image_path"]
if path.exists():
if use_cache:
logger.info(
Expand All @@ -150,15 +153,15 @@ def get_cutouts(

try:
download_cutout(
result["cutout_url"],
comp_result["cutout_url"],
out_file=path.as_posix(),
cache=True,
pkgname="cutouts",
timeout=timeout,
)
except Exception as e:
logger.warning(e)
result["error"] = str(e)
comp_result["error"] = str(e)

return results, comparison_results

Expand Down Expand Up @@ -275,8 +278,8 @@ def run_cutouts_from_precovery(
)

plot_candidates = []
for i, result in enumerate(cutout_results):
if "error" in result:
for i, cutout_result in enumerate(cutout_results):
if "error" in cutout_result:
candidate = {
"path": None,
"ra": observations["pred_ra_deg"].values[i],
Expand All @@ -295,7 +298,7 @@ def run_cutouts_from_precovery(
}
else:
candidate = {
"path": result["cutout_image_path"],
"path": cutout_result["cutout_image_path"],
"ra": observations["pred_ra_deg"].values[i],
"dec": observations["pred_dec_deg"].values[i],
"vra": observations["pred_vra_degpday"].values[i],
Expand All @@ -306,9 +309,9 @@ def run_cutouts_from_precovery(
"mag_sigma": observations["mag_sigma"].values[i],
"filter": observations["filter"].values[i],
"obscode": observations["obscode"].values[i],
"exposure_start": result["exposure_start_mjd"],
"exposure_duration": result["exposure_duration"],
"exposure_id": result["exposure_id"],
"exposure_start": cutout_result["exposure_start_mjd"],
"exposure_duration": cutout_result["exposure_duration"],
"exposure_id": cutout_result["exposure_id"],
}
plot_candidates.append(candidate)

Expand All @@ -323,8 +326,8 @@ def run_cutouts_from_precovery(

if compare:
plot_comparison_candidates = []
for i, result in enumerate(comparison_results):
if "error" in result:
for i, comparison_result in enumerate(comparison_results):
if "error" in comparison_result:
candidate = {
"path": None,
"ra": observations["pred_ra_deg"].values[i],
Expand All @@ -341,18 +344,18 @@ def run_cutouts_from_precovery(
}
else:
candidate = {
"path": result["cutout_image_path"],
"path": comparison_result["cutout_image_path"],
"ra": observations["pred_ra_deg"].values[i],
"dec": observations["pred_dec_deg"].values[i],
"vra": np.NaN,
"vdec": np.NaN,
"mag": np.NaN,
"mag_sigma": np.NaN,
"filter": result["filter"],
"filter": comparison_result["filter"],
"obscode": observations["obscode"].values[i],
"exposure_start": result["exposure_start_mjd"],
"exposure_duration": result["exposure_duration"],
"exposure_id": result["exposure_id"],
"exposure_start": comparison_result["exposure_start_mjd"],
"exposure_duration": comparison_result["exposure_duration"],
"exposure_id": comparison_result["exposure_id"],
}
plot_comparison_candidates.append(candidate)

Expand Down

0 comments on commit 12b8650

Please sign in to comment.