Skip to content

Commit

Permalink
Document the SDFRandomizer class
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Apr 26, 2020
1 parent af2c664 commit ed0c1cb
Showing 1 changed file with 140 additions and 0 deletions.
140 changes: 140 additions & 0 deletions python/gym_ignition/randomizers/model/sdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,46 @@ class UniformParams(NamedTuple):


class RandomizationDataBuilder:
"""
Builder class of a :py:class:`~gym_ignition.randomizers.model.sdf.RandomizationData`
object.
Args:
randomizer: The :py:class:`~gym_ignition.randomizers.model.sdf.SDFRandomizer`
object to which the created randomization will be inserted.
"""

def __init__(self, randomizer: "SDFRandomizer"):

self.storage: Dict = {}
self.randomizer = randomizer

def at_xpath(self, xpath: str) -> "RandomizationDataBuilder":
"""
Set the XPath pattern associated to the randomization.
Args:
xpath: The XPath pattern.
Returns:
The randomization builder to allow chaining methods.
"""
self.storage["xpath"] = xpath
return self

def sampled_from(self,
distribution: Distribution,
parameters: DistributionParameters) -> "RandomizationDataBuilder":
"""
Set the distribution associated to the randomization.
Args:
distribution: The desired distribution.
parameters: The namedtuple with the parameters of the distribution.
Returns:
The randomization builder to allow chaining methods.
"""

self.storage["distribution"] = distribution
self.storage["parameters"] = parameters
Expand All @@ -72,19 +99,62 @@ def sampled_from(self,
return self

def method(self, method: Method) -> "RandomizationDataBuilder":
"""
Set the randomization method.
Args:
method: The desired randomization method.
Returns:
The randomization builder to allow chaining methods.
"""

self.storage["method"] = method
return self

def ignore_zeros(self, ignore_zeros: bool) -> "RandomizationDataBuilder":
"""
Ignore the randomization of values that are zero.
If the value to randomize has a default value of 0 in the SDF, when this method
is chained the randomization is skipped. In the case of a multi-match XPath
pattern, the values that are not zero are not skipped.
Args:
ignore_zeros: True if zeros should be ignored, false otherwise.
Returns:
The randomization builder to allow chaining methods.
"""


self.storage["ignore_zeros"] = ignore_zeros
return self

def force_positive(self, force_positive: bool = True) -> "RandomizationDataBuilder":
"""
Force the randomized value to be greater than zero.
This option is helpful to enforce that values e.g. the mass will stay positive
regardless of the applied distribution parameters.
Args:
force_positive: True to force positive parameters, false otherwise.
Returns:
The randomization builder to allow chaining methods.
"""

self.storage["force_positive"] = force_positive
return self

def add(self) -> None:
"""
Close the chaining of methods are return to the SDF randomizer the configuration.
Raises:
RuntimeError: If the XPath pattern does not find any match in the SDF.
"""

data = RandomizationData(**self.storage)

Expand All @@ -95,6 +165,15 @@ def add(self) -> None:


class SDFRandomizer:
"""
Randomized SDF files generator.
Args:
sdf_model: The absolute path to the SDF file.
Raises:
ValueError: If the SDF file does not exist.
"""

def __init__(self, sdf_model: str):

Expand All @@ -117,12 +196,36 @@ def __init__(self, sdf_model: str):
self.rng = np.random.default_rng()

def seed(self, seed: int) -> None:
"""
Seed the SDF randomizer.
Args:
seed: The seed number.
"""
self.rng = np.random.default_rng(seed)

def find_xpath(self, xpath: str) -> List[etree.Element]:
"""
Find the elements that match an XPath pattern.
This method could be helpful to test the matches of a XPath pattern before using
it in :py:meth:`~gym_ignition.randomizers.model.sdf.RandomizationDataBuilder.at_xpath`.
Args:
xpath: The XPath pattern.
Return:
A list of elements matching the XPath pattern.
"""
return self._root.findall(xpath)

def process_data(self) -> None:
"""
Process all the inserted randomizations.
Raises:
RuntimeError: If the XPath of a randomization has no matches.
"""

# Since we support multi-match XPaths, we expand all the individual matches
expanded_randomizations = []
Expand Down Expand Up @@ -160,6 +263,19 @@ def process_data(self) -> None:
self._randomizations = expanded_randomizations

def sample(self, pretty_print=False) -> str:
"""
Sample a randomized SDF string.
Args:
pretty_print: True to pretty print the output.
Raises:
ValueError: If the distribution of a randomization is not recognized.
ValueError: If the method of a randomization is not recognized.
Returns:
The randomized model as SDF string.
"""

for data in self._randomizations:

Expand Down Expand Up @@ -200,15 +316,39 @@ def sample(self, pretty_print=False) -> str:
return etree.tostring(self._root, pretty_print=pretty_print).decode()

def new_randomization(self) -> RandomizationDataBuilder:
"""
Start the chaining to build a new randomization.
Return:
A randomization builder.
"""
return RandomizationDataBuilder(randomizer=self)

def insert(self, randomization_data) -> None:
"""
Insert a randomization.
Args:
randomization_data: A new randomization.
"""
self._randomizations.append(randomization_data)

def get_active_randomizations(self) -> List[RandomizationData]:
"""
Return the active randomizations.
This method could be helpful also in the case of multi-match XPath patterns to
validate that the inserted randomizations have been processed correctly.
Returns:
The list of the active randomizations.
"""
return self._randomizations

def clean(self) -> None:
"""
Clean the SDF randomizer.
"""

self._randomizations = []
self._default_values = {}
Expand Down

0 comments on commit ed0c1cb

Please sign in to comment.