Chore: Make release 1.0.88
martinroberson committed Jun 25, 2024
1 parent 2e7fa9a commit 7e946ad
Showing 3 changed files with 196 additions and 53 deletions.
@@ -27,7 +27,7 @@
from gs_quant.json_convertors import decode_optional_date
from gs_quant.priceable import PriceableImpl
from gs_quant.target.backtests import BacktestTradingQuantityType, EquityMarketModel
from gs_quant.target.common import PricingLocation
from gs_quant.common import PricingLocation, Currency


class TransactionCostModel(Enum):
@@ -40,6 +40,7 @@ class Transaction:
portfolio: Tuple[Instrument, ...]
portfolio_price: Optional[float] = None
cost: Optional[float] = None
currency: Optional[Currency] = None


@dataclass_json(letter_case=LetterCase.CAMEL)
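
The hunk above adds an optional currency to the Transaction dataclass. A minimal sketch of constructing one with the new field; the Transaction class below is a local stand-in mirroring the fields shown in the hunk, since the first file's path was not captured in this view:

from dataclasses import dataclass
from typing import Optional, Tuple

from gs_quant.common import Currency
from gs_quant.instrument import IRSwap

@dataclass
class Transaction:  # local stand-in mirroring the hunk; the real class also uses dataclass_json
    portfolio: Tuple  # Tuple[Instrument, ...] in the source
    portfolio_price: Optional[float] = None
    cost: Optional[float] = None
    currency: Optional[Currency] = None  # the field added by this commit

t = Transaction(portfolio=(IRSwap(),), portfolio_price=100.0, cost=0.25, currency=Currency.USD)
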
79 changes: 69 additions & 10 deletions gs_quant/api/gs/data.py
@@ -13,11 +13,12 @@
specific language governing permissions and limitations
under the License.
"""
import asyncio
import datetime as dt
import json
import logging
import time
from copy import copy
from copy import copy, deepcopy
from enum import Enum
from itertools import chain
from typing import Iterable, List, Optional, Tuple, Union, Dict
@@ -178,16 +179,31 @@ def set_api_request_cache(cls, cache: ApiRequestCache):
cls._api_request_cache = cache

@classmethod
def _post_with_cache_check(cls, url, **kwargs):
def _check_cache(cls, url, **kwargs):
session = cls.get_session()
cache_key = None
cached_val = None
if cls._api_request_cache:
cache_key = (url, 'POST', kwargs)
cached_val = cls._api_request_cache.get(session, cache_key)
if cached_val is not None:
return cached_val
result = session._post(url, **kwargs)
if cls._api_request_cache:
cls._api_request_cache.put(session, cache_key, result)
return cached_val, cache_key, session

@classmethod
def _post_with_cache_check(cls, url, **kwargs):
result, cache_key, session = cls._check_cache(url, **kwargs)
if result is None:
result = session._post(url, **kwargs)
if cls._api_request_cache:
cls._api_request_cache.put(session, cache_key, result)
return result

@classmethod
async def _post_with_cache_check_async(cls, url, **kwargs):
result, cache_key, session = cls._check_cache(url, **kwargs)
if result is None:
result = await session._post_async(url, **kwargs)
if cls._api_request_cache:
cls._api_request_cache.put(session, cache_key, result)
return result
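
Factoring the lookup into _check_cache lets the new async POST path share cache logic with the sync one instead of duplicating it. Below is a minimal sketch of a cache object satisfying the get/put calls seen here, wired in through set_api_request_cache; the dict-backed class is illustrative and merely duck-types the interface (the library's ApiRequestCache, referenced in the signature above, would be the natural parent):

from gs_quant.api.gs.data import GsDataApi

class DictRequestCache:
    """Illustrative in-memory cache; duck-types the get/put shape GsDataApi expects."""
    def __init__(self):
        self._store = {}

    def get(self, session, key):
        # key is (url, 'POST', kwargs) and contains a dict, so index by its repr
        return self._store.get(repr(key))

    def put(self, session, key, value):
        self._store[repr(key)] = value

GsDataApi.set_api_request_cache(DictRequestCache())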

@classmethod
@@ -204,16 +220,37 @@ def query_data(cls, query: Union[DataQuery, MDAPIDataQuery], dataset_id: str = N
response: Union[DataQueryResponse, dict] = cls.execute_query(dataset_id, query)
return cls.get_results(dataset_id, response, query)

@classmethod
async def query_data_async(cls, query: Union[DataQuery, MDAPIDataQuery], dataset_id: str = None) \
-> Union[MDAPIDataBatchResponse, DataQueryResponse, tuple, list]:
if isinstance(query, MDAPIDataQuery) and query.market_data_coordinates:
# Don't use MDAPIDataBatchResponse for now - it doesn't handle quoting style correctly
results: Union[MDAPIDataBatchResponse, dict] = await cls.execute_query_async('coordinates', query)
if isinstance(results, dict):
return results.get('responses', ())
else:
return results.responses if results.responses is not None else ()
response: Union[DataQueryResponse, dict] = await cls.execute_query_async(dataset_id, query)
results = await cls.get_results_async(dataset_id, response, query)
return results

@classmethod
def execute_query(cls, dataset_id: str, query: Union[DataQuery, MDAPIDataQuery]):
kwargs = {'payload': query}
if getattr(query, 'format', None) in (Format.MessagePack, 'MessagePack'):
kwargs['request_headers'] = {'Accept': 'application/msgpack'}
return cls._post_with_cache_check('/data/{}/query'.format(dataset_id), **kwargs)

@classmethod
async def execute_query_async(cls, dataset_id: str, query: Union[DataQuery, MDAPIDataQuery]):
kwargs = {'payload': query}
if getattr(query, 'format', None) in (Format.MessagePack, 'MessagePack'):
kwargs['request_headers'] = {'Accept': 'application/msgpack'}
result = await cls._post_with_cache_check_async('/data/{}/query'.format(dataset_id), **kwargs)
return result

@staticmethod
def get_results(dataset_id: str, response: Union[DataQueryResponse, dict], query: DataQuery) -> \
Union[list, Tuple[list, list]]:
def _get_results(response: Union[DataQueryResponse, dict]):
if isinstance(response, dict):
total_pages = response.get('totalPages')
results = response.get('data', [])
@@ -228,7 +265,12 @@ def get_results(dataset_id: str, response: Union[DataQueryResponse, dict], query
else:
total_pages = response.total_pages if response.total_pages is not None else 0
results = response.data if response.data is not None else ()
return results, total_pages

@staticmethod
def get_results(dataset_id: str, response: Union[DataQueryResponse, dict], query: DataQuery) -> \
Union[list, Tuple[list, list]]:
results, total_pages = GsDataApi._get_results(response)
if total_pages:
if query.page is None:
query.page = total_pages - 1
@@ -238,7 +280,21 @@ def get_results(dataset_id: str, response: Union[DataQueryResponse, dict], query
results = results + GsDataApi.get_results(dataset_id, GsDataApi.execute_query(dataset_id, query), query)
else:
return results
return results

@staticmethod
async def get_results_async(dataset_id: str, response: Union[DataQueryResponse, dict], query: DataQuery) -> \
Union[list, Tuple[list, list]]:
results, total_pages = GsDataApi._get_results(response)
if total_pages and total_pages > 1:
futures = []
for page in range(1, total_pages):
query = deepcopy(query)
query.page = page
futures.append(GsDataApi.execute_query_async(dataset_id, query))
all_responses = await asyncio.gather(*futures, return_exceptions=True)
for response_crt in all_responses:
results += GsDataApi._get_results(response_crt)[0]
return results
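
Unlike the sync get_results, which walks the remaining pages serially (and recursively), get_results_async fans them out with asyncio.gather, deep-copying the query per page so the page numbers do not clobber one another; note that return_exceptions=True means a failed page arrives in the gathered list as an exception object rather than raising. The fan-out pattern in isolation, as a generic runnable sketch rather than gs_quant code:

import asyncio
from copy import deepcopy

async def fetch_page(query):  # stand-in for execute_query_async
    await asyncio.sleep(0.01)
    return {'data': [query['page']]}

async def fetch_remaining(query, total_pages):
    futures = []
    for page in range(1, total_pages):  # page 0 came back with the initial response
        q = deepcopy(query)  # each task needs its own query object
        q['page'] = page
        futures.append(fetch_page(q))
    responses = await asyncio.gather(*futures)  # all pages in flight concurrently
    results = []
    for response in responses:
        results += response['data']
    return results

print(asyncio.run(fetch_remaining({'page': 0}, 4)))  # [1, 2, 3]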

@classmethod
@@ -298,7 +354,10 @@ def get_coverage(
results = scroll_results = body['results']
total_results = body['totalResults']
while len(scroll_results) and len(results) < total_results:
params['scrollId'] = body['scrollId']
scroll_id = body.get('scrollId')
if scroll_id is None:
break
params['scrollId'] = scroll_id
body = session._get(f'/data/{dataset_id}/coverage', payload=params)
scroll_results = body['results']
results += scroll_results
167 changes: 125 additions & 42 deletions gs_quant/data/dataset.py
@@ -107,6 +107,33 @@ def provider(self):
from gs_quant.api.gs.data import GsDataApi
return self.__provider or GsDataApi

def _build_data_query(
self, start: Union[dt.date, dt.datetime], end: Union[dt.date, dt.datetime], as_of: dt.datetime,
since: dt.datetime, fields: Iterable[Union[str, Fields]], empty_intervals: bool, **kwargs):
field_names = None if fields is None else list(map(lambda f: f if isinstance(f, str) else f.value, fields))
# check whether a function is called e.g. difference(tradePrice)
schema_varies = field_names is not None and any(map(lambda s: re.match("\\w+\\(", s), field_names))
if kwargs and "date" in kwargs:
d = kwargs["date"]
if type(d) is str:
try:
kwargs["date"] = dt.datetime.strptime(d, "%Y-%m-%d").date()
except ValueError:
pass # Ignore error if date parameter is in some other format
if "dates" not in kwargs and start is None and end is None:
kwargs["dates"] = (kwargs["date"],)
return self.provider.build_query(start=start, end=end, as_of=as_of, since=since, fields=field_names,
empty_intervals=empty_intervals, **kwargs), schema_varies

def _build_data_frame(self, data, schema_varies, standard_fields) -> pd.DataFrame:
if type(data) is tuple:
df = self.provider.construct_dataframe_with_types(self.id, data[0], schema_varies,
standard_fields=standard_fields)
return df.groupby(data[1], group_keys=True).apply(lambda x: x)
else:
return self.provider.construct_dataframe_with_types(self.id, data, schema_varies,
standard_fields=standard_fields)
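
In _build_data_frame, the tuple branch keeps the existing behaviour: groupby(..., group_keys=True).apply(lambda x: x) is an identity transform whose side effect is prepending the grouping columns to the row index. The idiom in isolation, in plain pandas:

import pandas as pd

df = pd.DataFrame({'assetId': ['A', 'A', 'B'], 'price': [1.0, 2.0, 3.0]})
# The apply is an identity, but group_keys=True prepends 'assetId' to the index,
# producing a MultiIndex of (assetId, original_row_label)
grouped = df.groupby('assetId', group_keys=True).apply(lambda x: x)
print(grouped.index)  # MultiIndex([('A', 0), ('A', 1), ('B', 2)], ...)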

def get_data(
self,
start: Optional[Union[dt.date, dt.datetime]] = None,
@@ -141,71 +168,52 @@ def get_data(
>>> weather_data = weather.get_data(dt.date(2016, 1, 15), dt.date(2016, 1, 16), city=('Boston', 'Austin'))
"""

field_names = None if fields is None else list(map(lambda f: f if isinstance(f, str) else f.value, fields))
# check whether a function is called e.g. difference(tradePrice)
schema_varies = field_names is not None and any(map(lambda s: re.match("\\w+\\(", s), field_names))
if kwargs and "date" in kwargs:
d = kwargs["date"]
if type(d) is str:
try:
kwargs["date"] = dt.datetime.strptime(d, "%Y-%m-%d").date()
except ValueError:
pass # Ignore error if date parameter is in some other format
if "dates" not in kwargs and start is None and end is None:
kwargs["dates"] = (kwargs["date"],)
query = self.provider.build_query(
start=start,
end=end,
as_of=as_of,
since=since,
fields=field_names,
empty_intervals=empty_intervals,
**kwargs
)
query, schema_varies = self._build_data_query(start, end, as_of, since, fields, empty_intervals, **kwargs)
data = self.provider.query_data(query, self.id, asset_id_type=asset_id_type)
if type(data) is tuple:
df = self.provider.construct_dataframe_with_types(self.id, data[0], schema_varies,
standard_fields=standard_fields)
return df.groupby(data[1], group_keys=True).apply(lambda x: x)
else:
return self.provider.construct_dataframe_with_types(self.id, data, schema_varies,
standard_fields=standard_fields)
return self._build_data_frame(data, schema_varies, standard_fields)

def get_data_series(
async def get_data_async(
self,
field: Union[str, Fields],
start: Optional[Union[dt.date, dt.datetime]] = None,
end: Optional[Union[dt.date, dt.datetime]] = None,
as_of: Optional[dt.datetime] = None,
since: Optional[dt.datetime] = None,
dates: Optional[List[dt.date]] = None,
fields: Optional[Iterable[Union[str, Fields]]] = None,
empty_intervals: Optional[bool] = None,
standard_fields: Optional[bool] = False,
**kwargs
) -> pd.Series:
) -> pd.DataFrame:
"""
Get a time series of data for a field of a dataset
Get data for the given range and parameters
:param field: The DataSet field to use
:param start: Requested start date/datetime for data
:param end: Requested end date/datetime for data
:param as_of: Request data as_of
:param since: Request data since
:param fields: DataSet fields to include
:param empty_intervals: whether to request empty intervals
:param standard_fields: If set, will use fields api instead of catalog api to get fieldTypes
:param kwargs: Extra query arguments, e.g. ticker='EDZ19'
:return: A Series of the requested data, indexed by date or time, depending on the DataSet
:return: A DataFrame of the requested data
**Examples**
>>> from gs_quant.data import Dataset
>>> import datetime as dt
>>>
>>> weather = Dataset('WEATHER')
>>> dew_point = weather.get_data_series('dewPoint', dt.date(2016, 1, 15), dt.date(2016, 1, 16),
>>>                                     city=('Boston', 'Austin'))
>>> weather_data = await weather.get_data_async(dt.date(2016, 1, 15), dt.date(2016, 1, 16),
>>> city=('Boston', 'Austin'))
"""

field_value = field if isinstance(field, str) else field.value
query, schema_varies = self._build_data_query(start, end, as_of, since, fields, empty_intervals, **kwargs)
data = await self.provider.query_data_async(query, self.id)
return self._build_data_frame(data, schema_varies, standard_fields)
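
get_data_async mirrors get_data but awaits the provider's async query path. A runnable sketch of calling it from synchronous code, assuming a GsSession has already been authenticated (parameters taken from the docstring example):

import asyncio
import datetime as dt

from gs_quant.data import Dataset

async def main():
    weather = Dataset('WEATHER')
    return await weather.get_data_async(dt.date(2016, 1, 15), dt.date(2016, 1, 16),
                                        city=('Boston', 'Austin'))

# Requires an authenticated GsSession (e.g. via GsSession.use) before running
weather_data = asyncio.run(main())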

def _build_data_series_query(self, field: Union[str, Fields], start: Union[dt.date, dt.datetime],
end: Union[dt.date, dt.datetime], as_of: dt.datetime, since: dt.datetime,
dates: List[dt.date], **kwargs):
field_value = field if isinstance(field, str) else field.value
query = self.provider.build_query(
start=start,
end=end,
@@ -215,13 +223,13 @@ def get_data_series(
dates=dates,
**kwargs
)

symbol_dimensions = self.provider.symbol_dimensions(self.id)
if len(symbol_dimensions) != 1:
raise MqValueError('get_data_series only valid for symbol_dimensions of length 1')

symbol_dimension = symbol_dimensions[0]
data = self.provider.query_data(query, self.id)
return field_value, query, symbol_dimension

def _build_data_series(self, data, field_value, symbol_dimension, standard_fields: bool) -> pd.Series:
df = self.provider.construct_dataframe_with_types(self.id, data, standard_fields=standard_fields)

from gs_quant.api.gs.data import GsDataApi
@@ -230,14 +238,89 @@ def get_data_series(
gb = df.groupby(symbol_dimension)
if len(gb.groups) > 1:
raise MqValueError('Not a series for a single {}'.format(symbol_dimension))

if df.empty:
return pd.Series(dtype=float)
if '(' in field_value:
field_value = field_value.replace('(', '_')
field_value = field_value.replace(')', '')
return pd.Series(index=df.index, data=df.loc[:, field_value].values)

def get_data_series(
self,
field: Union[str, Fields],
start: Optional[Union[dt.date, dt.datetime]] = None,
end: Optional[Union[dt.date, dt.datetime]] = None,
as_of: Optional[dt.datetime] = None,
since: Optional[dt.datetime] = None,
dates: Optional[List[dt.date]] = None,
standard_fields: Optional[bool] = False,
**kwargs
) -> pd.Series:
"""
Get a time series of data for a field of a dataset
:param field: The DataSet field to use
:param start: Requested start date/datetime for data
:param end: Requested end date/datetime for data
:param as_of: Request data as_of
:param since: Request data since
:param dates: Requested dates for data
:param standard_fields: If set, will use fields api instead of catalog api to get fieldTypes
:param kwargs: Extra query arguments, e.g. ticker='EDZ19'
:return: A Series of the requested data, indexed by date or time, depending on the DataSet
**Examples**
>>> from gs_quant.data import Dataset
>>> import datetime as dt
>>>
>>> weather = Dataset('WEATHER')
>>> dew_point = weather.get_data_series('dewPoint', dt.date(2016, 1, 15), dt.date(2016, 1, 16),
>>>                                     city=('Boston', 'Austin'))
"""
field_value, query, symbol_dimension = self._build_data_series_query(field, start, end, as_of, since, dates,
**kwargs)
data = self.provider.query_data(query, self.id)
return self._build_data_series(data, field_value, symbol_dimension, standard_fields)

async def get_data_series_async(
self,
field: Union[str, Fields],
start: Optional[Union[dt.date, dt.datetime]] = None,
end: Optional[Union[dt.date, dt.datetime]] = None,
as_of: Optional[dt.datetime] = None,
since: Optional[dt.datetime] = None,
dates: Optional[List[dt.date]] = None,
standard_fields: Optional[bool] = False,
**kwargs
) -> pd.Series:
"""
Get a time series of data for a field of a dataset
:param field: The DataSet field to use
:param start: Requested start date/datetime for data
:param end: Requested end date/datetime for data
:param as_of: Request data as_of
:param since: Request data since
:param dates: Requested dates for data
:param standard_fields: If set, will use fields api instead of catalog api to get fieldTypes
:param kwargs: Extra query arguments, e.g. ticker='EDZ19'
:return: A Series of the requested data, indexed by date or time, depending on the DataSet
**Examples**
>>> from gs_quant.data import Dataset
>>> import datetime as dt
>>>
>>> weather = Dataset('WEATHER')
>>> dew_point = await weather.get_data_series_async('dewPoint', dt.date(2016, 1, 15), dt.date(2016, 1, 16),
>>> city=('Boston', 'Austin'))
"""
field_value, query, symbol_dimension = self._build_data_series_query(field, start, end, as_of, since, dates,
**kwargs)
data = await self.provider.query_data_async(query, self.id)
return self._build_data_series(data, field_value, symbol_dimension, standard_fields)
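
The async variant pays off when several series are needed at once: asyncio.gather puts the requests in flight together rather than issuing sequential round trips. A sketch under the same session assumption (the second field name is illustrative):

import asyncio
import datetime as dt

from gs_quant.data import Dataset

async def fetch_series():
    weather = Dataset('WEATHER')
    start, end = dt.date(2016, 1, 15), dt.date(2016, 1, 16)
    return await asyncio.gather(  # both requests run concurrently
        weather.get_data_series_async('dewPoint', start, end, city=('Boston',)),
        weather.get_data_series_async('pressure', start, end, city=('Boston',)),
    )

dew_point, pressure = asyncio.run(fetch_series())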

def get_data_last(
self,
as_of: Optional[Union[dt.date, dt.datetime]],