Skip to content

Commit

Permalink
Merge pull request #24 from B612-Asteroid-Institute/nt/remove_pandera
Browse files Browse the repository at this point in the history
After running bump-pydantic
  • Loading branch information
ntellis authored Jan 26, 2024
2 parents 2416fc4 + 12b8650 commit 9cf6b1d
Show file tree
Hide file tree
Showing 9 changed files with 175 additions and 150 deletions.
11 changes: 4 additions & 7 deletions cutouts/filter.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import logging

import numpy as np
import pandera as pa
from pandera.typing import DataFrame
import pandas as pd

from .io.types import CutoutRequest, CutoutResult, CutoutsResultSchema
from .io.types import CutoutRequest, CutoutResult

logger = logging.getLogger(__file__)


@pa.check_types
def select_cutout(
results_df: DataFrame[CutoutsResultSchema], cutout_request: CutoutRequest
results_df: pd.DataFrame, cutout_request: CutoutRequest
) -> CutoutResult:
"""
Select the cutout closest to the requested exposure start time +- delta_time.
Expand Down Expand Up @@ -91,9 +89,8 @@ def select_cutout(
return result


@pa.check_types
def select_comparison_cutout(
results_df: DataFrame[CutoutsResultSchema],
results_df: pd.DataFrame,
cutout_result: CutoutResult,
cutout_request: CutoutRequest,
min_time_separation: float = 1 / 24,
Expand Down
8 changes: 3 additions & 5 deletions cutouts/io/nsc.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import logging

import pandera as pa
from pandera.typing import DataFrame
import pandas as pd
from pyvo.dal.sia import SIAResults

from .sia import SIAHandler
from .types import CutoutRequest, CutoutsResultSchema
from .types import CutoutRequest
from .util import exposure_id_from_url

logger = logging.getLogger(__name__)
Expand All @@ -16,10 +15,9 @@ def _get_generic_image_url_from_cutout_url(cutout_url: str):
return cutout_url.split("&POS=")[0]


@pa.check_types
def find_cutouts_nsc_dr2(
cutout_request: CutoutRequest,
) -> DataFrame[CutoutsResultSchema]:
) -> pd.DataFrame:
"""
Search the NOIRLab Archive for cutouts and images at a given RA, Dec.
Expand Down
7 changes: 2 additions & 5 deletions cutouts/io/skymapper.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import logging

import pandas as pd
import pandera as pa
from pandera.typing import DataFrame
from pyvo.dal.sia import SIAResults

from .sia import SIAHandler
from .types import CutoutRequest, CutoutsResultSchema
from .types import CutoutRequest

logger = logging.getLogger(__name__)

Expand All @@ -22,10 +20,9 @@ def _get_generic_image_url_from_cutout_url(cutout_url: str):
return url_string


@pa.check_types
def find_cutouts_skymapper_dr2(
cutout_request: CutoutRequest,
) -> DataFrame[CutoutsResultSchema]:
) -> pd.DataFrame:
"""
Search the Skymapper SIA service for cutouts and images at a given RA, Dec.
Expand Down
25 changes: 0 additions & 25 deletions cutouts/io/tests/test_nsc_dr2.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import os
from contextlib import contextmanager
from unittest.mock import patch
import pathlib
import pickle
import pandas as pd

from astropy.io.votable import parse
from pyvo.dal.sia import SIAResults

from ..nsc import NSC_DR2_SIA, find_cutouts_nsc_dr2
from ...main import get_cutouts
from ..types import CutoutRequest

# 2014 HE199 (2014-04-28T08:07:52.435)
Expand Down Expand Up @@ -44,7 +40,6 @@ def mock_sia_nsc_dr2_query(table_file: str):


def test_sia_nsc_dr2_query():

with mock_sia_nsc_dr2_query(
"nsc_dr2_227.5251615214173_-27.026013823449265_56775.33880132809.xml"
) as mock:
Expand All @@ -71,23 +66,3 @@ def test_sia_nsc_dr2_query():
assert col in results.columns

assert "VR" in results["filter"].values


def test_sia_nsc_dr2_query_():

with mock_sia_nsc_dr2_query(
"nsc_dr2_227.5251615214173_-27.026013823449265_56775.33880132809.xml"
) as mock:
results, comparison_results = get_cutouts(
pd.DataFrame(cutout_request1),
out_dir=pathlib.Path(os.path.dirname(os.path.abspath(__file__)))
)
print(results)
print(comparison_results)

with open('cutout_results.pickle', 'wb') as f:
pickle.dump(results, f)
with open('cutout_comparison_results.pickle', 'wb') as f:
pickle.dump(comparison_results, f)

assert False
60 changes: 9 additions & 51 deletions cutouts/io/types.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,21 @@
from typing import Optional

import pandera as pa
from pandera.typing import Series
from pydantic import BaseModel


class CutoutRequestSchema(pa.SchemaModel):
"""
Dataframe validation for multiple cutout requests
"""

request_id: Optional[Series[str]] = pa.Field(nullable=True)
observatory_code: Series[str] = pa.Field(coerce=True)
exposure_start_mjd: Series[float] = pa.Field(nullable=False, coerce=True)
ra_deg: Series[float] = pa.Field(ge=0, le=360, coerce=True)
dec_deg: Series[float] = pa.Field(ge=-90, le=90, coerce=True)
filter: Optional[
Series[str]
] = pa.Field() # TODO: validate against real list of filters?
exposure_id: Optional[Series[str]] = pa.Field(nullable=True)
exposure_duration: Optional[Series[float]] = pa.Field(
ge=0, le=2000, coerce=True, nullable=True
)
height_arcsec: Series[float] = pa.Field(ge=0, le=200, coerce=True, nullable=True)
width_arcsec: Series[float] = pa.Field(ge=0, le=200, coerce=True, nullable=True)
delta_time: Series[float] = pa.Field(ge=0, le=100, coerce=True, nullable=True)


class CutoutsResultSchema(pa.SchemaModel):
# TODO: ra, dec here are the ra, dec returned by the query and
# these should be equal to the queried ra and dec.
# However, this may not always be true and we may want to consider
# adding additional fields to allow backends to return things such as the
# the center of the image/cutout.
ra_deg: Series[float] = pa.Field(ge=0, le=360, coerce=True)
dec_deg: Series[float] = pa.Field(ge=-90, le=90, coerce=True)
filter: Series[str] = pa.Field()
exposure_id: Series[str] = pa.Field()
exposure_start_mjd: Series[float] = pa.Field(nullable=False, coerce=True)
exposure_duration: Series[float] = pa.Field(ge=0, le=2000, coerce=True)
cutout_url: Series[str] = pa.Field(coerce=True)
image_url: Series[str] = pa.Field(coerce=True)
height_arcsec: Series[float] = pa.Field(ge=0, le=200, coerce=True)
width_arcsec: Series[float] = pa.Field(ge=0, le=200, coerce=True)


class CutoutRequest(BaseModel):
"""
A single cutout request
"""

request_id: Optional[str]
request_id: Optional[str] = None
observatory_code: str
exposure_start_mjd: float
ra_deg: float
dec_deg: float
filter: Optional[str]
exposure_id: Optional[str]
exposure_duration: Optional[float]
filter: Optional[str] = None
exposure_id: Optional[str] = None
exposure_duration: Optional[float] = None
height_arcsec: float
width_arcsec: float
delta_time: float
Expand All @@ -69,13 +27,13 @@ class CutoutResult(BaseModel):
"""

cutout_url: str
dec_deg: Optional[float]
dec_deg: Optional[float] = None
exposure_duration: float
exposure_id: Optional[str]
exposure_id: Optional[str] = None
exposure_start_mjd: float
filter: Optional[str]
filter: Optional[str] = None
height_arcsec: float
image_url: str
ra_deg: Optional[float]
request_id: Optional[str]
ra_deg: Optional[float] = None
request_id: Optional[str] = None
width_arcsec: float
7 changes: 2 additions & 5 deletions cutouts/io/ztf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
import backoff
import numpy as np
import pandas as pd
import pandera as pa
import requests
from pandera.typing import DataFrame

from .types import CutoutRequest, CutoutsResultSchema
from .types import CutoutRequest

logger = logging.getLogger(__file__)

Expand Down Expand Up @@ -83,8 +81,7 @@ def perform_request(search_url):
return response.text


@pa.check_types()
def find_cutouts_ztf(cutout_request: CutoutRequest) -> DataFrame[CutoutsResultSchema]:
def find_cutouts_ztf(cutout_request: CutoutRequest) -> pd.DataFrame:
"""
Search the ZTF service for cutouts and images at a given RA, Dec.
Expand Down
Loading

0 comments on commit 9cf6b1d

Please sign in to comment.