Skip to content

Commit

Permalink
[releases/2.1] Cherry-pick: [High priority] Measurement Client: Raise…
Browse files Browse the repository at this point in the history
… errors for invalid enums (#920)

[High priority] Measurement Client: Raise errors for invalid enums (#916)

* Fix: Raise errors for invalid enums

* Refactor: Method names for clarity

* Fix: Transform invalid enum values

* Tests: Add unit tests for client

* Revert: main.py change

* Fix: Refine regex in support.py

* Tests: Update expected and actual assert positioning

(cherry picked from commit 4a4c464)
  • Loading branch information
MounikaBattu17 authored Sep 24, 2024
1 parent 66d21db commit c0d41be
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
create_class_name,
create_module_name,
extract_base_service_class,
get_configuration_metadata_by_index,
get_configuration_and_output_metadata_by_index,
get_configuration_parameters_with_type_and_default_values,
get_measurement_service_stub,
get_output_metadata_by_index,
get_output_parameters_with_type,
get_all_registered_measurement_info,
get_selected_measurement_service_class,
Expand Down Expand Up @@ -70,10 +69,9 @@ def _create_client(
discovery_client, channel_pool, measurement_service_class
)
metadata = measurement_service_stub.GetMetadata(v2_measurement_service_pb2.GetMetadataRequest())
configuration_metadata = get_configuration_metadata_by_index(
configuration_metadata, output_metadata = get_configuration_and_output_metadata_by_index(
metadata, measurement_service_class, enum_values_by_type
)
output_metadata = get_output_metadata_by_index(metadata, enum_values_by_type)

configuration_parameters_with_type_and_default_values, measure_api_parameters = (
get_configuration_parameters_with_type_and_default_values(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
import sys
from enum import Enum
from typing import AbstractSet, Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar
from typing import AbstractSet, Dict, Iterable, List, Optional, Tuple, Type, TypeVar

import click
import grpc
Expand Down Expand Up @@ -101,24 +101,31 @@ def get_all_registered_measurement_info(
return measurement_service_classes, measurement_display_names


def get_configuration_metadata_by_index(
def get_configuration_and_output_metadata_by_index(
metadata: v2_measurement_service_pb2.GetMetadataResponse,
service_class: str,
enum_values_by_type: Dict[Type[Enum], Dict[str, int]] = {},
) -> Dict[int, ParameterMetadata]:
"""Returns the configuration metadata of the measurement."""
) -> Tuple[Dict[int, ParameterMetadata], Dict[int, ParameterMetadata]]:
"""Returns the configuration and output metadata of the measurement."""
configuration_parameter_list = []
for configuration in metadata.measurement_signature.configuration_parameters:
annotations_dict = dict(configuration.annotations.items())
if _is_enum_param(configuration.type):
annotations_dict["ni/enum.values"] = _validate_and_transform_enum_annotations(
configuration.annotations["ni/enum.values"]
)
configuration_parameter_list.append(
ParameterMetadata.initialize(
display_name=configuration.name,
type=configuration.type,
repeated=configuration.repeated,
default_value=None,
annotations=dict(configuration.annotations.items()),
annotations=annotations_dict,
message_type=configuration.message_type,
enum_type=(
_get_enum_type(configuration, enum_values_by_type)
_get_enum_type(
configuration.name, annotations_dict["ni/enum.values"], enum_values_by_type
)
if _is_enum_param(configuration.type)
else None
),
Expand All @@ -127,16 +134,23 @@ def get_configuration_metadata_by_index(

output_parameter_list = []
for output in metadata.measurement_signature.outputs:
annotations_dict = dict(output.annotations.items())
if _is_enum_param(output.type):
annotations_dict["ni/enum.values"] = _validate_and_transform_enum_annotations(
output.annotations["ni/enum.values"]
)
output_parameter_list.append(
ParameterMetadata.initialize(
display_name=output.name,
type=output.type,
repeated=output.repeated,
default_value=None,
annotations=dict(output.annotations.items()),
annotations=annotations_dict,
message_type=output.message_type,
enum_type=(
_get_enum_type(output, enum_values_by_type)
_get_enum_type(
output.name, annotations_dict["ni/enum.values"], enum_values_by_type
)
if _is_enum_param(output.type)
else None
),
Expand All @@ -150,6 +164,7 @@ def get_configuration_metadata_by_index(
pool=descriptor_pool.Default(),
)
configuration_metadata = frame_metadata_dict(configuration_parameter_list)
output_metadata = frame_metadata_dict(output_parameter_list)
deserialized_parameters = deserialize_parameters(
configuration_metadata,
metadata.measurement_signature.configuration_defaults.value,
Expand All @@ -166,33 +181,7 @@ def get_configuration_metadata_by_index(

configuration_metadata[k] = configuration_metadata[k]._replace(default_value=default_value)

return configuration_metadata


def get_output_metadata_by_index(
metadata: v2_measurement_service_pb2.GetMetadataResponse,
enum_values_by_type: Dict[Type[Enum], Dict[str, int]] = {},
) -> Dict[int, ParameterMetadata]:
"""Returns the output metadata of the measurement."""
output_parameter_list = []
for output in metadata.measurement_signature.outputs:
output_parameter_list.append(
ParameterMetadata.initialize(
display_name=output.name,
type=output.type,
repeated=output.repeated,
default_value=None,
annotations=dict(output.annotations.items()),
message_type=output.message_type,
enum_type=(
_get_enum_type(output, enum_values_by_type)
if _is_enum_param(output.type)
else None
),
)
)
output_metadata = frame_metadata_dict(output_parameter_list)
return output_metadata
return configuration_metadata, output_metadata


def get_configuration_parameters_with_type_and_default_values(
Expand Down Expand Up @@ -229,7 +218,9 @@ def get_configuration_parameters_with_type_and_default_values(
)

if metadata.annotations and metadata.annotations.get("ni/type_specialization") == "enum":
enum_type = _get_enum_type(metadata, enum_values_by_type)
enum_type = _get_enum_type(
metadata.display_name, metadata.annotations["ni/enum.values"], enum_values_by_type
)
parameter_type = enum_type.__name__
if metadata.repeated:
values = []
Expand Down Expand Up @@ -280,7 +271,9 @@ def get_output_parameters_with_type(
parameter_type = f"List[{parameter_type}]"

if metadata.annotations and metadata.annotations.get("ni/type_specialization") == "enum":
enum_type_name = _get_enum_type(metadata, enum_values_by_type).__name__
enum_type_name = _get_enum_type(
metadata.display_name, metadata.annotations["ni/enum.values"], enum_values_by_type
).__name__
parameter_type = f"List[{enum_type_name}]" if metadata.repeated else enum_type_name

output_parameters_with_type.append(f"{parameter_name}: {parameter_type}")
Expand Down Expand Up @@ -410,16 +403,16 @@ def _is_enum_param(parameter_type: int) -> bool:


def _get_enum_type(
parameter: Any, enum_values_by_type: Dict[Type[Enum], Dict[str, int]]
parameter_name: str,
enum_annotations: str,
enum_values_by_type: Dict[Type[Enum], Dict[str, int]],
) -> Type[Enum]:
loaded_enum_values = json.loads(parameter.annotations["ni/enum.values"])
enum_values = {key: value for key, value in loaded_enum_values.items()}

enum_values = dict(json.loads(enum_annotations))
for existing_enum_type, existing_enum_values in enum_values_by_type.items():
if existing_enum_values == enum_values:
return existing_enum_type

new_enum_type_name = _get_enum_class_name(parameter.name)
new_enum_type_name = _get_enum_class_name(parameter_name)
# MyPy error: Enum() expects a string literal as the first argument.
# Ignoring this error because MyPy cannot validate dynamic Enum creation statically.
new_enum_type = Enum(new_enum_type_name, enum_values) # type: ignore[misc]
Expand All @@ -435,3 +428,22 @@ def _get_enum_class_name(name: str) -> str:
else:
name = name[0].upper() + name[1:]
return f"{name}Enum"


def _validate_and_transform_enum_annotations(enum_annotations: str) -> str:
enum_values = dict(json.loads(enum_annotations))
transformed_enum_annotations = {}
for enum_value, value in enum_values.items():
original_enum_value = enum_value

enum_value = re.sub(r"\W+", "_", enum_value)
if enum_value[0].isdigit():
enum_value = f"k_{enum_value}"

# Check for enum values that are only special characters.
if not enum_value.strip("_"):
raise click.ClickException(f"The enum value '{original_enum_value}' is invalid.")

transformed_enum_annotations[enum_value] = value

return json.dumps(transformed_enum_annotations)
36 changes: 36 additions & 0 deletions packages/generator/tests/unit/test_measurement_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest
from click import ClickException

from ni_measurement_plugin_sdk_generator.client._support import (
_validate_and_transform_enum_annotations,
)


@pytest.mark.parametrize(
"enum_annotations, expected_enum_annotations",
[
('{"NONE": 0, "RED": 1, "GREEN": 2}', '{"NONE": 0, "RED": 1, "GREEN": 2}'),
(
'{"DC Volts": 0, "2-Wire Resistance": 1, "5 1/2": 2}',
'{"DC_Volts": 0, "k_2_Wire_Resistance": 1, "k_5_1_2": 2}',
),
],
)
def test___enum_annotations___validate_and_transform_enum_annotations___returns_expected_enum_annotations(
enum_annotations: str, expected_enum_annotations: str
) -> None:
actual_enum_annotations = _validate_and_transform_enum_annotations(enum_annotations)

assert actual_enum_annotations == expected_enum_annotations


def test___invalid_enum_annotations___validate_and_transform_enum_annotations___raises_invalid_enum_value_error() -> (
None
):
enum_annotations = '{"DC Volts": 0, "*": 1}'
expected_error_message = "The enum value '*' is invalid."

with pytest.raises(ClickException) as exc_info:
_ = _validate_and_transform_enum_annotations(enum_annotations)

assert exc_info.value.message == expected_error_message

0 comments on commit c0d41be

Please sign in to comment.