diff --git a/smsdk/client.py b/smsdk/client.py index 943fd6e..4c51d8e 100644 --- a/smsdk/client.py +++ b/smsdk/client.py @@ -654,8 +654,6 @@ def get_machine_names(self, source_type=None, clean_strings_out=True): query_params["source_type"] = source_type machines = self.get_data_v1("machine_v1", "get_machines", True, **query_params) - # machine api endpoint doesn't seem to accept query params - machines = machines[machines["source_type"] == source_type] if clean_strings_out: return machines["source_clean"].to_list() diff --git a/smsdk/smsdk_entities/machine/machine_v1.py b/smsdk/smsdk_entities/machine/machine_v1.py index 98c6382..b48eb14 100644 --- a/smsdk/smsdk_entities/machine/machine_v1.py +++ b/smsdk/smsdk_entities/machine/machine_v1.py @@ -12,6 +12,7 @@ from smsdk.utils import module_utility from smsdk import config from smsdk.ma_session import MaSession +from urllib.parse import urlencode, urlunparse ENDPOINTS = json.loads(pkg_resources.read_text(config, "api_endpoints.json")) @@ -42,10 +43,10 @@ def get_machines(self, *args, **kwargs): Recommend to use 'enable_pagination':True for larger datasets """ url = "{}{}".format(self.base_url, ENDPOINTS["Machine"]["url_v1"]) + url = self.modify_query_params(url, kwargs) records = self._get_records_v1( url, method="get", results_under="objects", **kwargs ) - # records = self._get_records_v1(url, method="get", **kwargs) if not isinstance(records, List): raise ValueError("Error - {}".format(records)) return records @@ -64,3 +65,48 @@ def get_type_from_machine_name(self, machine_source, *args, **kwargs): ): machine_type = record["type"] return machine_type + + def modify_query_params(self, url, kwargs): + where = [] + order_by = [] + params = {} + for key, value in kwargs.items(): + where_query = {} + orderby_query = {} + select_query = {} + if not key.startswith("_"): + where_query["name"] = key + where_query["value"] = value + where.append(where_query) + elif key == "_order_by": + orderby_query["name"] = ( + value.replace("-", "") if value.startswith("-") else value + ) + orderby_query["order"] = "desc" if value.startswith("-") else "asc" + order_by.append(orderby_query) + + """ + Other keys we are ignoring from above loop. + _only and _limit are handled in below implementatioin + """ + + # Using eval as kwargs.get("_only") gives list in string format + select = [{"name": i} for i in eval(kwargs.get("_only", "[]"))] + limit = kwargs.get("_limit", None) + if len(where) > 0: + where = json.dumps(where, ensure_ascii=False) + params["where"] = where + if len(order_by) > 0: + order_by = json.dumps(order_by, ensure_ascii=False) + params["order_by"] = order_by + if len(select) > 0: + select = json.dumps(select, ensure_ascii=False) + params["select"] = select + if limit: + params["limit"] = limit + + if params: + encoded_params = urlencode(params) + url = urlunparse(("", "", url, "", encoded_params, "")) + + return url diff --git a/tests/downtime/test_downtime.py b/tests/downtime/test_downtime.py index 872c867..d5caa62 100644 --- a/tests/downtime/test_downtime.py +++ b/tests/downtime/test_downtime.py @@ -10,7 +10,7 @@ MACHINE_INDEX = 0 START_DATETIME = datetime(2023, 4, 1) END_DATETIME = datetime(2023, 4, 2) -EXPECTED_ROWS = 15 +EXPECTED_ROWS = 18 EXPECTED_COL = 8 URL_V1 = "/v1/datatab/downtime" diff --git a/tests/machine/test_machine.py b/tests/machine/test_machine.py index f318aaf..2dd1c6d 100644 --- a/tests/machine/test_machine.py +++ b/tests/machine/test_machine.py @@ -136,3 +136,19 @@ def test_get_machine_schema_types_return_mtype(mocked_types, mocked_machines): def test_get_machines_v1(get_client): machines = get_client.get_machines() assert machines.shape == (49, 10) + + +def test_get_machines_with_query_params(get_client): + limit = 20 + query_params = { + "_only": ["source", "source_clean", "source_type"], + "source_type": "Lasercut", + "_order_by": "source_clean", + "_limit": limit, + } + machines = get_client.get_machines(**query_params) + + assert len(machines) == limit + + # Checking that we should only get these three columns that we have provided on query params. + assert machines.columns.tolist() == ["source", "source_clean", "source_type"]