From 12b86505f3db306d6bc6f72f978d1c8dc19606ee Mon Sep 17 00:00:00 2001 From: Nate Tellis Date: Fri, 26 Jan 2024 13:39:40 -0500 Subject: [PATCH] Make mypy happy --- cutouts/main.py | 63 ++++++++++++++++++++++++++----------------------- 1 file changed, 33 insertions(+), 30 deletions(-) diff --git a/cutouts/main.py b/cutouts/main.py index c9be92b..a8f6e43 100644 --- a/cutouts/main.py +++ b/cutouts/main.py @@ -2,7 +2,7 @@ import logging import pathlib import sys -from typing import Any, Dict, Iterable, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import numpy as np import pandas as pd @@ -34,7 +34,7 @@ def get_cutouts( download_full_image: bool = False, compare: bool = False, compare_kwargs: Optional[dict] = None, -) -> Tuple[Iterable[Dict[str, Any]], Iterable[Dict[str, Any]]]: +) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: """ """ # Get urls and metadata for each cutout @@ -47,13 +47,13 @@ def get_cutouts( "same_filter": True, } - results = [] - comparison_results = [] + results: List[Dict[str, Any]] = [] + comparison_results: List[Dict[str, Any]] = [] for record in cutout_requests.to_dict(orient="records"): try: cutout_request = CutoutRequest(**record) # type: ignore results_df = find_cutouts(cutout_request) - result = select_cutout(results_df, cutout_request) + result = dict(select_cutout(results_df, cutout_request)) except Exception as e: logger.warning(e) @@ -61,21 +61,24 @@ def get_cutouts( if compare: try: - comparison_result = select_comparison_cutout( - results_df, result, cutout_request, **compare_kwargs + comparison_result = dict( + select_comparison_cutout( + results_df, + CutoutRequest(result), + cutout_request, + **compare_kwargs, + ) ) except Exception as e: logger.warning(e) comparison_result = {"error": e} - comparison_result = dict(comparison_result) comparison_results.append(comparison_result) else: - comparison_results.append(None) + comparison_results.append({"error": "Comparison cutout not requested"}) - result = dict(result) results.append(result) for result in results: @@ -128,19 +131,19 @@ def get_cutouts( # Download comparison cutouts if compare: - for result in comparison_results: - if result is not None: - if "error" in result: + for comp_result in comparison_results: + if comp_result is not None: + if "error" in comp_result: continue ( comparison_full_image_path, comparison_cutout_image_path, ) = generate_local_image_paths(result) - result["cutout_image_path"] = ( + comp_result["cutout_image_path"] = ( pathlib.Path(out_dir) / "comparison" / comparison_cutout_image_path ) - path = result["cutout_image_path"] + path = comp_result["cutout_image_path"] if path.exists(): if use_cache: logger.info( @@ -150,7 +153,7 @@ def get_cutouts( try: download_cutout( - result["cutout_url"], + comp_result["cutout_url"], out_file=path.as_posix(), cache=True, pkgname="cutouts", @@ -158,7 +161,7 @@ def get_cutouts( ) except Exception as e: logger.warning(e) - result["error"] = str(e) + comp_result["error"] = str(e) return results, comparison_results @@ -275,8 +278,8 @@ def run_cutouts_from_precovery( ) plot_candidates = [] - for i, result in enumerate(cutout_results): - if "error" in result: + for i, cutout_result in enumerate(cutout_results): + if "error" in cutout_result: candidate = { "path": None, "ra": observations["pred_ra_deg"].values[i], @@ -295,7 +298,7 @@ def run_cutouts_from_precovery( } else: candidate = { - "path": result["cutout_image_path"], + "path": cutout_result["cutout_image_path"], "ra": observations["pred_ra_deg"].values[i], "dec": observations["pred_dec_deg"].values[i], "vra": observations["pred_vra_degpday"].values[i], @@ -306,9 +309,9 @@ def run_cutouts_from_precovery( "mag_sigma": observations["mag_sigma"].values[i], "filter": observations["filter"].values[i], "obscode": observations["obscode"].values[i], - "exposure_start": result["exposure_start_mjd"], - "exposure_duration": result["exposure_duration"], - "exposure_id": result["exposure_id"], + "exposure_start": cutout_result["exposure_start_mjd"], + "exposure_duration": cutout_result["exposure_duration"], + "exposure_id": cutout_result["exposure_id"], } plot_candidates.append(candidate) @@ -323,8 +326,8 @@ def run_cutouts_from_precovery( if compare: plot_comparison_candidates = [] - for i, result in enumerate(comparison_results): - if "error" in result: + for i, comparison_result in enumerate(comparison_results): + if "error" in comparison_result: candidate = { "path": None, "ra": observations["pred_ra_deg"].values[i], @@ -341,18 +344,18 @@ def run_cutouts_from_precovery( } else: candidate = { - "path": result["cutout_image_path"], + "path": comparison_result["cutout_image_path"], "ra": observations["pred_ra_deg"].values[i], "dec": observations["pred_dec_deg"].values[i], "vra": np.NaN, "vdec": np.NaN, "mag": np.NaN, "mag_sigma": np.NaN, - "filter": result["filter"], + "filter": comparison_result["filter"], "obscode": observations["obscode"].values[i], - "exposure_start": result["exposure_start_mjd"], - "exposure_duration": result["exposure_duration"], - "exposure_id": result["exposure_id"], + "exposure_start": comparison_result["exposure_start_mjd"], + "exposure_duration": comparison_result["exposure_duration"], + "exposure_id": comparison_result["exposure_id"], } plot_comparison_candidates.append(candidate)