Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[releases/2.1] Cherry-pick: [High priority] Measurement Client: Raise errors for invalid enums #920

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
create_class_name,
create_module_name,
extract_base_service_class,
get_configuration_metadata_by_index,
get_configuration_and_output_metadata_by_index,
get_configuration_parameters_with_type_and_default_values,
get_measurement_service_stub,
get_output_metadata_by_index,
get_output_parameters_with_type,
get_all_registered_measurement_info,
get_selected_measurement_service_class,
Expand Down Expand Up @@ -70,10 +69,9 @@ def _create_client(
discovery_client, channel_pool, measurement_service_class
)
metadata = measurement_service_stub.GetMetadata(v2_measurement_service_pb2.GetMetadataRequest())
configuration_metadata = get_configuration_metadata_by_index(
configuration_metadata, output_metadata = get_configuration_and_output_metadata_by_index(
metadata, measurement_service_class, enum_values_by_type
)
output_metadata = get_output_metadata_by_index(metadata, enum_values_by_type)

configuration_parameters_with_type_and_default_values, measure_api_parameters = (
get_configuration_parameters_with_type_and_default_values(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
import sys
from enum import Enum
from typing import AbstractSet, Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar
from typing import AbstractSet, Dict, Iterable, List, Optional, Tuple, Type, TypeVar

import click
import grpc
Expand Down Expand Up @@ -101,24 +101,31 @@ def get_all_registered_measurement_info(
return measurement_service_classes, measurement_display_names


def get_configuration_metadata_by_index(
def get_configuration_and_output_metadata_by_index(
metadata: v2_measurement_service_pb2.GetMetadataResponse,
service_class: str,
enum_values_by_type: Dict[Type[Enum], Dict[str, int]] = {},
) -> Dict[int, ParameterMetadata]:
"""Returns the configuration metadata of the measurement."""
) -> Tuple[Dict[int, ParameterMetadata], Dict[int, ParameterMetadata]]:
"""Returns the configuration and output metadata of the measurement."""
configuration_parameter_list = []
for configuration in metadata.measurement_signature.configuration_parameters:
annotations_dict = dict(configuration.annotations.items())
if _is_enum_param(configuration.type):
annotations_dict["ni/enum.values"] = _validate_and_transform_enum_annotations(
configuration.annotations["ni/enum.values"]
)
configuration_parameter_list.append(
ParameterMetadata.initialize(
display_name=configuration.name,
type=configuration.type,
repeated=configuration.repeated,
default_value=None,
annotations=dict(configuration.annotations.items()),
annotations=annotations_dict,
message_type=configuration.message_type,
enum_type=(
_get_enum_type(configuration, enum_values_by_type)
_get_enum_type(
configuration.name, annotations_dict["ni/enum.values"], enum_values_by_type
)
if _is_enum_param(configuration.type)
else None
),
Expand All @@ -127,16 +134,23 @@ def get_configuration_metadata_by_index(

output_parameter_list = []
for output in metadata.measurement_signature.outputs:
annotations_dict = dict(output.annotations.items())
if _is_enum_param(output.type):
annotations_dict["ni/enum.values"] = _validate_and_transform_enum_annotations(
output.annotations["ni/enum.values"]
)
output_parameter_list.append(
ParameterMetadata.initialize(
display_name=output.name,
type=output.type,
repeated=output.repeated,
default_value=None,
annotations=dict(output.annotations.items()),
annotations=annotations_dict,
message_type=output.message_type,
enum_type=(
_get_enum_type(output, enum_values_by_type)
_get_enum_type(
output.name, annotations_dict["ni/enum.values"], enum_values_by_type
)
if _is_enum_param(output.type)
else None
),
Expand All @@ -150,6 +164,7 @@ def get_configuration_metadata_by_index(
pool=descriptor_pool.Default(),
)
configuration_metadata = frame_metadata_dict(configuration_parameter_list)
output_metadata = frame_metadata_dict(output_parameter_list)
deserialized_parameters = deserialize_parameters(
configuration_metadata,
metadata.measurement_signature.configuration_defaults.value,
Expand All @@ -166,33 +181,7 @@ def get_configuration_metadata_by_index(

configuration_metadata[k] = configuration_metadata[k]._replace(default_value=default_value)

return configuration_metadata


def get_output_metadata_by_index(
metadata: v2_measurement_service_pb2.GetMetadataResponse,
enum_values_by_type: Dict[Type[Enum], Dict[str, int]] = {},
) -> Dict[int, ParameterMetadata]:
"""Returns the output metadata of the measurement."""
output_parameter_list = []
for output in metadata.measurement_signature.outputs:
output_parameter_list.append(
ParameterMetadata.initialize(
display_name=output.name,
type=output.type,
repeated=output.repeated,
default_value=None,
annotations=dict(output.annotations.items()),
message_type=output.message_type,
enum_type=(
_get_enum_type(output, enum_values_by_type)
if _is_enum_param(output.type)
else None
),
)
)
output_metadata = frame_metadata_dict(output_parameter_list)
return output_metadata
return configuration_metadata, output_metadata


def get_configuration_parameters_with_type_and_default_values(
Expand Down Expand Up @@ -229,7 +218,9 @@ def get_configuration_parameters_with_type_and_default_values(
)

if metadata.annotations and metadata.annotations.get("ni/type_specialization") == "enum":
enum_type = _get_enum_type(metadata, enum_values_by_type)
enum_type = _get_enum_type(
metadata.display_name, metadata.annotations["ni/enum.values"], enum_values_by_type
)
parameter_type = enum_type.__name__
if metadata.repeated:
values = []
Expand Down Expand Up @@ -280,7 +271,9 @@ def get_output_parameters_with_type(
parameter_type = f"List[{parameter_type}]"

if metadata.annotations and metadata.annotations.get("ni/type_specialization") == "enum":
enum_type_name = _get_enum_type(metadata, enum_values_by_type).__name__
enum_type_name = _get_enum_type(
metadata.display_name, metadata.annotations["ni/enum.values"], enum_values_by_type
).__name__
parameter_type = f"List[{enum_type_name}]" if metadata.repeated else enum_type_name

output_parameters_with_type.append(f"{parameter_name}: {parameter_type}")
Expand Down Expand Up @@ -410,16 +403,16 @@ def _is_enum_param(parameter_type: int) -> bool:


def _get_enum_type(
parameter: Any, enum_values_by_type: Dict[Type[Enum], Dict[str, int]]
parameter_name: str,
enum_annotations: str,
enum_values_by_type: Dict[Type[Enum], Dict[str, int]],
) -> Type[Enum]:
loaded_enum_values = json.loads(parameter.annotations["ni/enum.values"])
enum_values = {key: value for key, value in loaded_enum_values.items()}

enum_values = dict(json.loads(enum_annotations))
for existing_enum_type, existing_enum_values in enum_values_by_type.items():
if existing_enum_values == enum_values:
return existing_enum_type

new_enum_type_name = _get_enum_class_name(parameter.name)
new_enum_type_name = _get_enum_class_name(parameter_name)
# MyPy error: Enum() expects a string literal as the first argument.
# Ignoring this error because MyPy cannot validate dynamic Enum creation statically.
new_enum_type = Enum(new_enum_type_name, enum_values) # type: ignore[misc]
Expand All @@ -435,3 +428,22 @@ def _get_enum_class_name(name: str) -> str:
else:
name = name[0].upper() + name[1:]
return f"{name}Enum"


def _validate_and_transform_enum_annotations(enum_annotations: str) -> str:
enum_values = dict(json.loads(enum_annotations))
transformed_enum_annotations = {}
for enum_value, value in enum_values.items():
original_enum_value = enum_value

enum_value = re.sub(r"\W+", "_", enum_value)
if enum_value[0].isdigit():
enum_value = f"k_{enum_value}"

# Check for enum values that are only special characters.
if not enum_value.strip("_"):
raise click.ClickException(f"The enum value '{original_enum_value}' is invalid.")

transformed_enum_annotations[enum_value] = value

return json.dumps(transformed_enum_annotations)
36 changes: 36 additions & 0 deletions packages/generator/tests/unit/test_measurement_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest
from click import ClickException

from ni_measurement_plugin_sdk_generator.client._support import (
_validate_and_transform_enum_annotations,
)


@pytest.mark.parametrize(
"enum_annotations, expected_enum_annotations",
[
('{"NONE": 0, "RED": 1, "GREEN": 2}', '{"NONE": 0, "RED": 1, "GREEN": 2}'),
(
'{"DC Volts": 0, "2-Wire Resistance": 1, "5 1/2": 2}',
'{"DC_Volts": 0, "k_2_Wire_Resistance": 1, "k_5_1_2": 2}',
),
],
)
def test___enum_annotations___validate_and_transform_enum_annotations___returns_expected_enum_annotations(
enum_annotations: str, expected_enum_annotations: str
) -> None:
actual_enum_annotations = _validate_and_transform_enum_annotations(enum_annotations)

assert actual_enum_annotations == expected_enum_annotations


def test___invalid_enum_annotations___validate_and_transform_enum_annotations___raises_invalid_enum_value_error() -> (
None
):
enum_annotations = '{"DC Volts": 0, "*": 1}'
expected_error_message = "The enum value '*' is invalid."

with pytest.raises(ClickException) as exc_info:
_ = _validate_and_transform_enum_annotations(enum_annotations)

assert exc_info.value.message == expected_error_message