diff --git a/src/pseudopeople/configuration/entities.py b/src/pseudopeople/configuration/entities.py index d7e800c1..c216200c 100644 --- a/src/pseudopeople/configuration/entities.py +++ b/src/pseudopeople/configuration/entities.py @@ -6,6 +6,5 @@ class Keys: PROBABILITY = "probability" CELL_PROBABILITY = "cell_probability" TOKEN_PROBABILITY = "token_probability" - INCLUDE_ORIGINAL_TOKEN_PROBABILITY = "include_original_token_probability" POSSIBLE_AGE_DIFFERENCES = "possible_age_differences" ZIPCODE_DIGIT_PROBABILITIES = "digit_probabilities" diff --git a/src/pseudopeople/configuration/validator.py b/src/pseudopeople/configuration/validator.py index 8b2e8f02..50b2f9ff 100644 --- a/src/pseudopeople/configuration/validator.py +++ b/src/pseudopeople/configuration/validator.py @@ -66,7 +66,7 @@ def _validate_noise_type_config( parameter_config_validator = { Keys.POSSIBLE_AGE_DIFFERENCES: _validate_possible_age_differences, Keys.ZIPCODE_DIGIT_PROBABILITIES: _validate_zipcode_digit_probabilities, - }.get(parameter, _validate_standard_parameters) + }.get(parameter, _validate_probability) _ = _get_default_config_node( default_noise_type_config, parameter, "parameter", dataset, column, noise_type @@ -161,10 +161,10 @@ def _validate_zipcode_digit_probabilities( f"{len(noise_type_config)} probabilities ({noise_type_config})." ) for value in noise_type_config: - _validate_standard_parameters(value, parameter, base_error_message) + _validate_probability(value, parameter, base_error_message) -def _validate_standard_parameters( +def _validate_probability( noise_type_config: Union[int, float], parameter: str, base_error_message: str ) -> None: if not isinstance(noise_type_config, (float, int)): diff --git a/src/pseudopeople/entity_types.py b/src/pseudopeople/entity_types.py index 51f6b53f..80d6d014 100644 --- a/src/pseudopeople/entity_types.py +++ b/src/pseudopeople/entity_types.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Optional import pandas as pd from loguru import logger @@ -25,7 +25,7 @@ class RowNoiseType: """ name: str - noise_function: Callable[[pd.DataFrame, ConfigTree, RandomnessStream], pd.DataFrame] + noise_function: Callable[[str, pd.DataFrame, ConfigTree, RandomnessStream], pd.DataFrame] probability: float = 0.0 def __call__( @@ -56,7 +56,7 @@ class ColumnNoiseType: name: str noise_function: Callable[[pd.Series, ConfigTree, RandomnessStream, Any], pd.Series] - probability: float = 0.01 + probability: Optional[float] = 0.01 noise_level_scaling_function: Callable[[str], float] = lambda x: 1.0 additional_parameters: Dict[str, Any] = None diff --git a/src/pseudopeople/noise_entities.py b/src/pseudopeople/noise_entities.py index bf481ecb..6b0dac12 100644 --- a/src/pseudopeople/noise_entities.py +++ b/src/pseudopeople/noise_entities.py @@ -92,7 +92,6 @@ class __NoiseTypes(NamedTuple): additional_parameters={ # TODO: need to clarify these Keys.CELL_PROBABILITY: 0.01, Keys.TOKEN_PROBABILITY: 0.1, - Keys.INCLUDE_ORIGINAL_TOKEN_PROBABILITY: 0.1, }, ) diff --git a/src/pseudopeople/noise_functions.py b/src/pseudopeople/noise_functions.py index b545235a..5f19334c 100644 --- a/src/pseudopeople/noise_functions.py +++ b/src/pseudopeople/noise_functions.py @@ -374,7 +374,8 @@ def keyboard_corrupt(truth, corrupted_pr, addl_pr, rng): return err token_noise_level = configuration[Keys.TOKEN_PROBABILITY] - include_token_probability_level = configuration[Keys.INCLUDE_ORIGINAL_TOKEN_PROBABILITY] + # TODO: remove this hard-coding + include_token_probability_level = 0.1 rng = np.random.default_rng(seed=randomness_stream.seed) column = column.astype(str) diff --git a/tests/unit/test_column_noise.py b/tests/unit/test_column_noise.py index 7d316baa..5741eba9 100644 --- a/tests/unit/test_column_noise.py +++ b/tests/unit/test_column_noise.py @@ -569,7 +569,6 @@ def test_generate_typographical_errors(dummy_dataset, column): NOISE_TYPES.typographic.name: { Keys.CELL_PROBABILITY: 0.1, Keys.TOKEN_PROBABILITY: 0.1, - Keys.INCLUDE_ORIGINAL_TOKEN_PROBABILITY: 0.1, }, }, }, @@ -598,7 +597,8 @@ def test_generate_typographical_errors(dummy_dataset, column): # Check for expected string growth due to keeping original noised token assert (check_noised.str.len() >= check_original.str.len()).all() - p_include_original_token = config[Keys.INCLUDE_ORIGINAL_TOKEN_PROBABILITY] + # TODO: remove this hard-coding + p_include_original_token = 0.1 p_token_does_not_increase_string_length = 1 - p_token_noise * p_include_original_token p_strings_do_not_increase_length = ( p_token_does_not_increase_string_length**str_lengths