From f66c7f58c2746df19f6411729f0b0acf0cad9e66 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 8 Feb 2024 10:39:29 +0100 Subject: [PATCH 1/8] Initialize jaxsim.mujoco package --- src/jaxsim/mujoco/__init__.py | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 src/jaxsim/mujoco/__init__.py diff --git a/src/jaxsim/mujoco/__init__.py b/src/jaxsim/mujoco/__init__.py new file mode 100644 index 000000000..955bea2fe --- /dev/null +++ b/src/jaxsim/mujoco/__init__.py @@ -0,0 +1,3 @@ +from .loaders import RodModelToMjcf, SdfToMjcf, UrdfToMjcf +from .model import MujocoModelHelper +from .visualizer import MujocoVisualizer From 8fc37d596fc8dd613c9880ffa2754ddfe3855907 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 9 Feb 2024 08:26:31 +0100 Subject: [PATCH 2/8] Add jaxsim.mujoco.loaders to generate MJCF descriptions --- src/jaxsim/mujoco/loaders.py | 480 +++++++++++++++++++++++++++++++++++ 1 file changed, 480 insertions(+) create mode 100644 src/jaxsim/mujoco/loaders.py diff --git a/src/jaxsim/mujoco/loaders.py b/src/jaxsim/mujoco/loaders.py new file mode 100644 index 000000000..ba0b4a55d --- /dev/null +++ b/src/jaxsim/mujoco/loaders.py @@ -0,0 +1,480 @@ +import pathlib +import tempfile +import warnings +from typing import Any + +import mujoco as mj +import rod.urdf.exporter +from lxml import etree as ET + + +def load_rod_model( + model_description: str | pathlib.Path | rod.Model, + is_urdf: bool | None = None, + model_name: str | None = None, +) -> rod.Model: + """""" + + # Parse the SDF resource. + sdf_element = rod.Sdf.load(sdf=model_description, is_urdf=is_urdf) + + # Fail if the SDF resource has no model. + if len(sdf_element.models()) == 0: + raise RuntimeError("Failed to find any model in the model description") + + # Return the model if there is only one. + if len(sdf_element.models()) == 1: + if model_name is not None and sdf_element.models()[0].name != model_name: + raise ValueError(f"Model '{model_name}' not found in the description") + + return sdf_element.models()[0] + + # Require users to specify the model name if there are multiple models. + if model_name is None: + msg = "The resource has multiple models. Please specify the model name." + raise ValueError(msg) + + # Build a dictionary of models in the resource for easy access. + models = {m.name: m for m in sdf_element.models()} + + if model_name not in models: + raise ValueError(f"Model '{model_name}' not found in the resource") + + return models[model_name] + + +class RodModelToMjcf: + """""" + + @staticmethod + def assets_from_rod_model( + rod_model: rod.Model, + ) -> dict[str, bytes]: + """""" + + import resolve_robotics_uri_py + + assets_files = dict() + + for link in rod_model.links(): + for visual in link.visuals(): + if visual.geometry.mesh and visual.geometry.mesh.uri: + assets_files[visual.geometry.mesh.uri] = ( + resolve_robotics_uri_py.resolve_robotics_uri( + visual.geometry.mesh.uri + ) + ) + + for collision in link.collisions(): + if collision.geometry.mesh and collision.geometry.mesh.uri: + assets_files[collision.geometry.mesh.uri] = ( + resolve_robotics_uri_py.resolve_robotics_uri( + collision.geometry.mesh.uri + ) + ) + + assets = { + asset_name: asset.read_bytes() for asset_name, asset in assets_files.items() + } + + return assets + + @staticmethod + def add_floating_joint( + urdf_string: str, + base_link_name: str, + floating_joint_name: str = "world_to_base", + ) -> str: + """""" + + with tempfile.NamedTemporaryFile(mode="w+", suffix=".urdf") as urdf_file: + + # Write the URDF string to a temporary file and move current position + # to the beginning. + urdf_file.write(urdf_string) + urdf_file.seek(0) + + # Parse the MJCF string as XML (etree). + parser = ET.XMLParser(remove_blank_text=True) + tree = ET.parse(source=urdf_file, parser=parser) + + root: ET._Element = tree.getroot() + + if root.find(f".//joint[@name='{floating_joint_name}']") is not None: + msg = f"The URDF already has a floating joint '{floating_joint_name}'" + warnings.warn(msg) + return ET.tostring(root, pretty_print=True).decode() + + # Create the "world" link if it doesn't exist. + if root.find(".//link[@name='world']") is None: + _ = ET.SubElement(root, "link", name="world") + + # Create the floating joint. + world_to_base = ET.SubElement( + root, "joint", name=floating_joint_name, type="floating" + ) + + # Check that the base link exists. + if root.find(f".//link[@name='{base_link_name}']") is None: + raise ValueError(f"Link '{base_link_name}' not found in the URDF") + + # Attach the floating joint to the base link. + ET.SubElement(world_to_base, "parent", link="world") + ET.SubElement(world_to_base, "child", link=base_link_name) + + urdf_string = ET.tostring(root, pretty_print=True).decode() + return urdf_string + + @staticmethod + def convert( + rod_model: rod.Model, + considered_joints: list[str] | None = None, + ) -> tuple[str, dict[str, Any]]: + """""" + + # ------------------------------------- + # Convert the model description to URDF + # ------------------------------------- + + # Consider all joints if not specified otherwise. + considered_joints = set( + considered_joints + if considered_joints is not None + else [j.name for j in rod_model.joints() if j.type != "fixed"] + ) + + # If considered joints are passed, make sure that they are all part of the model. + if considered_joints - set([j.name for j in rod_model.joints()]): + extra_joints = set(considered_joints) - set( + [j.name for j in rod_model.joints()] + ) + msg = f"Couldn't find the following joints in the model: '{extra_joints}'" + raise ValueError(msg) + + # Create a dictionary of joints for quick access. + joints_dict = {j.name: j for j in rod_model.joints()} + + # Convert all the joints not considered to fixed joints. + for joint_name in set([j.name for j in rod_model.joints()]) - considered_joints: + joints_dict[joint_name].type = "fixed" + + # Convert the ROD model to URDF. + urdf_string = rod.urdf.exporter.UrdfExporter.sdf_to_urdf_string( + sdf=rod.Sdf(model=rod_model, version="1.7"), + gazebo_preserve_fixed_joints=False, + pretty=True, + ) + + # ------------------------------------- + # Add a floating joint if floating-base + # ------------------------------------- + + if not rod_model.is_fixed_base(): + considered_joints |= {"world_to_base"} + urdf_string = RodModelToMjcf.add_floating_joint( + urdf_string=urdf_string, + base_link_name=rod_model.get_canonical_link(), + floating_joint_name="world_to_base", + ) + + # --------------------------------------- + # Inject the element in the URDF + # --------------------------------------- + + parser = ET.XMLParser(remove_blank_text=True) + root = ET.fromstring(text=urdf_string.encode(), parser=parser) + + mujoco_element = ( + ET.SubElement(root, "mujoco") + if len(root.findall("./mujoco")) == 0 + else root.find("./mujoco") + ) + + _ = ET.SubElement( + mujoco_element, + "compiler", + balanceinertia="true", + discardvisual="false", + ) + + urdf_string = ET.tostring(root, pretty_print=True).decode() + # print(urdf_string) + # raise + + # ------------------------------ + # Post-process all dummy visuals + # ------------------------------ + + parser = ET.XMLParser(remove_blank_text=True) + root: ET._Element = ET.fromstring(text=urdf_string.encode(), parser=parser) + import numpy as np + + # Give a tiny radius to all dummy spheres + for geometry in root.findall(".//visual/geometry[sphere]"): + radius = np.fromstring( + geometry.find("./sphere").attrib["radius"], sep=" ", dtype=float + ) + if np.allclose(radius, np.zeros(1)): + geometry.find("./sphere").set("radius", "0.001") + + # Give a tiny volume to all dummy boxes + for geometry in root.findall(".//visual/geometry[box]"): + size = np.fromstring( + geometry.find("./box").attrib["size"], sep=" ", dtype=float + ) + if np.allclose(size, np.zeros(3)): + geometry.find("./box").set("size", "0.001 0.001 0.001") + + urdf_string = ET.tostring(root, pretty_print=True).decode() + + # ------------------------ + # Convert the URDF to MJCF + # ------------------------ + + # Load the URDF model into Mujoco. + assets = RodModelToMjcf.assets_from_rod_model(rod_model=rod_model) + mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets) # noqa + + # Get the joint names. + mj_joint_names = set( + [ + mj.mj_id2name(mj_model, mj.mjtObj.mjOBJ_JOINT, idx) + for idx in range(mj_model.njnt) + ] + ) + + # Check that the Mujoco model only has the considered joints. + if mj_joint_names != considered_joints: + extra1 = mj_joint_names - considered_joints + extra2 = considered_joints - mj_joint_names + extra_joints = extra1.union(extra2) + msg = "The Mujoco model has the following extra/missing joints: '{}'" + raise ValueError(msg.format(extra_joints)) + + with tempfile.NamedTemporaryFile( + mode="w+", suffix=".xml", prefix=f"{rod_model.name}_" + ) as mjcf_file: + + # Convert the in-memory Mujoco model to MJCF. + mj.mj_saveLastXML(mjcf_file.name, mj_model) + + # Parse the MJCF string as XML (etree). + # We need to post-process the file to include additional elements. + parser = ET.XMLParser(remove_blank_text=True) + tree = ET.parse(source=mjcf_file, parser=parser) + + # Get the root element. + root: ET._Element = tree.getroot() + + # Find the element (might be the root itself). + mujoco_elements = [e for e in root.iter("mujoco")] + mujoco_element: ET._Element = mujoco_elements[0] + + # -------------- + # Add the motors + # -------------- + + if len(mujoco_element.findall(".//actuator")) > 0: + raise RuntimeError("The model already has elements.") + + # Add the actuator element. + actuator_element = ET.SubElement(mujoco_element, "actuator") + + # Add a motor for each joint. + for joint_element in mujoco_element.findall(".//joint"): + assert ( + joint_element.attrib["name"] in considered_joints + ), joint_element.attrib["name"] + if joint_element.attrib.get("type", "hinge") in {"free", "ball"}: + continue + ET.SubElement( + actuator_element, + "motor", + name=f"{joint_element.attrib['name']}_motor", + joint=joint_element.attrib["name"], + gear="1", + ) + + # --------------------------------------------- + # Set full transparency of collision geometries + # --------------------------------------------- + + parser = ET.XMLParser(remove_blank_text=True) + + # Get all the (optional) names of the URDF collision elements + collision_names = { + c.attrib["name"] + for c in ET.fromstring(text=urdf_string.encode(), parser=parser).findall( + ".//collision[geometry]" + ) + if "name" in c.attrib + } + + # Set alpha=0 to the color of all collision elements + for geometry_element in mujoco_element.findall(".//geom[@rgba]"): + if geometry_element.attrib.get("name") in collision_names: + r, g, b, a = geometry_element.attrib["rgba"].split(" ") + geometry_element.set("rgba", f"{r} {g} {b} 0") + + # ----------------------- + # Create the scene assets + # ----------------------- + + asset_element = ( + ET.SubElement(mujoco_element, "asset") + if len(mujoco_element.findall(".//asset")) == 0 + else mujoco_element.find(".//asset") + ) + + _ = ET.SubElement( + asset_element, + "texture", + type="skybox", + builtin="gradient", + rgb1="0.3 0.5 0.7", + rgb2="0 0 0", + width="512", + height="512", + ) + + _ = ET.SubElement( + asset_element, + "texture", + name="plane_texture", + type="2d", + builtin="checker", + rgb1="0.1 0.2 0.3", + rgb2="0.2 0.3 0.4", + width="512", + height="512", + mark="cross", + markrgb=".8 .8 .8", + ) + + _ = ET.SubElement( + asset_element, + "material", + name="plane_material", + texture="plane_texture", + reflectance="0.2", + texrepeat="5 5", + texuniform="true", + ) + + # ---------------------------------- + # Populate the scene with the assets + # ---------------------------------- + + worldbody_scene_element = ET.SubElement(mujoco_element, "worldbody") + + _ = ET.SubElement( + worldbody_scene_element, + "geom", + name="floor", + type="plane", + size="0 0 0.05", + material="plane_material", + condim="3", + contype="1", + conaffinity="1", + ) + + _ = ET.SubElement( + worldbody_scene_element, + "light", + name="sun", + mode="fixed", + directional="true", + castshadow="true", + pos="0 0 10", + dir="0 0 -1", + ) + + # ------------------------------------------------ + # Add a light following the CoM of the first link + # ------------------------------------------------ + + if not rod_model.is_fixed_base(): + + worldbody_element = None + + # Find the element of our model by searching the one that contains + # all the considered joints. This is needed because there might be multiple + # elements inside . + for wb in mujoco_element.findall(".//worldbody"): + if all( + [ + wb.find(f".//joint[@name='{j}']") is not None + for j in considered_joints + ] + ): + worldbody_element = wb + break + + if worldbody_element is None: + raise RuntimeError( + "Failed to find the element of the model" + ) + + # Light attached to the model + _ = ET.SubElement( + worldbody_element, + "light", + name="light_model", + mode="targetbodycom", + target=worldbody_element.find(".//body").attrib["name"], + directional="false", + castshadow="true", + pos="1 0 5", + ) + + # -------------------------------- + # Return the resulting MJCF string + # -------------------------------- + + mjcf_string = ET.tostring(root, pretty_print=True).decode() + return mjcf_string, assets + + +class UrdfToMjcf: + @staticmethod + def convert( + urdf: str | pathlib.Path, + considered_joints: list[str] | None = None, + model_name: str | None = None, + ) -> tuple[str, dict[str, Any]]: + """""" + + # Get the ROD model. + rod_model = load_rod_model( + model_description=urdf, + is_urdf=True, + model_name=model_name, + ) + + # Convert the ROD model to MJCF. + return RodModelToMjcf.convert( + rod_model=rod_model, considered_joints=considered_joints + ) + + +class SdfToMjcf: + @staticmethod + def convert( + sdf: str | pathlib.Path, + considered_joints: list[str] | None = None, + model_name: str | None = None, + ) -> tuple[str, dict[str, Any]]: + """""" + + # Get the ROD model. + rod_model = load_rod_model( + model_description=sdf, + is_urdf=False, + model_name=model_name, + ) + + # Convert the ROD model to MJCF. + return RodModelToMjcf.convert( + rod_model=rod_model, considered_joints=considered_joints + ) From 7ec8e8b59f1b1fb9c7808495047cd1b0d062f05f Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 9 Feb 2024 08:22:56 +0100 Subject: [PATCH 3/8] Add jaxsim.mujoco.model with helpers to get and set quantities --- src/jaxsim/mujoco/model.py | 352 +++++++++++++++++++++++++++++++++++++ 1 file changed, 352 insertions(+) create mode 100644 src/jaxsim/mujoco/model.py diff --git a/src/jaxsim/mujoco/model.py b/src/jaxsim/mujoco/model.py new file mode 100644 index 000000000..82e095f54 --- /dev/null +++ b/src/jaxsim/mujoco/model.py @@ -0,0 +1,352 @@ +import functools +import pathlib +from typing import Any + +import mujoco as mj +import numpy as np +import numpy.typing as npt +from scipy.spatial.transform import Rotation + + +class MujocoModelHelper: + """ + Helper class to create and interact with Mujoco models and data objects. + """ + + def __init__(self, model: mj.MjModel, data: mj.MjData | None = None) -> None: + """""" + + self.model = model + self.data = data if data is not None else mj.MjData(self.model) + + # Populate the data with kinematics + mj.mj_forward(self.model, self.data) + + # Keep the cache of this method local to improve GC + self.mask_qpos = functools.cache(self._mask_qpos) + + @staticmethod + def build_from_xml( + mjcf_description: str | pathlib.Path, assets: dict[str, Any] = None + ) -> "MujocoModelHelper": + """""" + + # Read the XML description if it's a path to file + mjcf_description = ( + mjcf_description.read_text() + if isinstance(mjcf_description, pathlib.Path) + else mjcf_description + ) + + # Create the Mujoco model from the XML and, optionally, the assets dictionary + model = mj.MjModel.from_xml_string(xml=mjcf_description, assets=assets) # noqa + + return MujocoModelHelper(model=model, data=mj.MjData(model)) + + def time(self) -> float: + """Return the simulation time.""" + + return self.data.time + + def timestep(self) -> float: + """Return the simulation timestep.""" + + return self.model.opt.timestep + + def gravity(self) -> npt.NDArray: + """Return the 3D gravity vector.""" + + return self.model.opt.gravity + + # ========================= + # Methods for the base link + # ========================= + + def is_floating_base(self) -> bool: + """Return true if the model is floating-base.""" + + # A body with no joints is considered a fixed-base model. + # In fact, in mujoco, a floating-base model has a 6 DoFs first joint. + if self.number_of_joints() == 0: + return False + + # We just check that the first joint has 6 DoFs. + joint0_type = self.model.jnt_type[0] + return joint0_type == mj.mjtJoint.mjJNT_FREE + + def is_fixed_base(self) -> bool: + """Return true if the model is fixed-base.""" + + return not self.is_floating_base() + + def base_link(self) -> str: + """Return the name of the base link.""" + + return mj.mj_id2name( + self.model, mj.mjtObj.mjOBJ_BODY, 0 if self.is_fixed_base() else 1 + ) + + def base_position(self) -> npt.NDArray: + """Return the 3D position of the base link.""" + + return ( + self.data.qpos[:3] + if self.is_floating_base() + else self.body_position(body_name=self.base_link()) + ) + + def base_orientation(self, dcm: bool = False) -> npt.NDArray: + """Return the orientation of the base link.""" + + return ( + ( + np.reshape(self.data.xmat[0], newshape=(3, 3)) + if dcm is True + else self.data.xquat[0] + ) + if self.is_floating_base() + else self.body_orientation(body_name=self.base_link(), dcm=dcm) + ) + + def set_base_position(self, position: npt.NDArray) -> None: + """Set the 3D position of the base link.""" + + if self.is_fixed_base(): + raise ValueError("The position of a fixed-base model cannot be set.") + + position = np.atleast_1d(np.array(position).squeeze()) + + if position.size != 3: + raise ValueError(f"Wrong position size ({position.size})") + + self.data.qpos[:3] = position + + def set_base_orientation(self, orientation: npt.NDArray, dcm: bool = False) -> None: + """Set the 3D position of the base link.""" + + if self.is_fixed_base(): + raise ValueError("The orientation of a fixed-base model cannot be set.") + + orientation = ( + np.atleast_2d(np.array(orientation).squeeze()) + if dcm + else np.atleast_1d(np.array(orientation).squeeze()) + ) + + if orientation.shape != ((4,) if not dcm else (3, 3)): + raise ValueError(f"Wrong orientation shape {orientation.shape}") + + def is_quaternion(Q): + return np.allclose(np.linalg.norm(Q), 1.0) + + def is_dcm(R): + return np.allclose(np.linalg.det(R), 1.0) and np.allclose( + R.T @ R, np.eye(3) + ) + + if not (is_quaternion(orientation) if not dcm else is_dcm(orientation)): + raise ValueError("The orientation is not a valid element of SO(3)") + + W_Q_B = ( + Rotation.from_matrix(orientation).as_quat(canonical=True)[ + np.array([3, 0, 1, 2]) + ] + if dcm + else orientation + ) + + self.data.qpos[3:7] = W_Q_B + + # ================== + # Methods for joints + # ================== + + def number_of_joints(self) -> int: + """""" + + return self.model.njnt + + def number_of_dofs(self) -> int: + """""" + + return self.model.nq + + def joint_names(self) -> list[str]: + """""" + + return [ + mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_JOINT, idx) + for idx in range(0 if self.is_fixed_base() else 1, self.number_of_joints()) + ] + + def joint_dofs(self, joint_name: str) -> int: + """""" + + if joint_name not in self.joint_names(): + raise ValueError(f"Joint '{joint_name}' not found") + + return self.data.joint(joint_name).qpos.size + + def joint_position(self, joint_name: str) -> npt.NDArray: + """""" + + if joint_name not in self.joint_names(): + raise ValueError(f"Joint '{joint_name}' not found") + + return self.data.joint(joint_name).qpos + + def joint_positions(self, joint_names: list[str] | None = None) -> npt.NDArray: + """""" + + joint_names = joint_names if joint_names is not None else self.joint_names() + + return np.hstack( + [self.joint_position(joint_name) for joint_name in joint_names] + ) + + def set_joint_position( + self, joint_name: str, position: npt.NDArray | float + ) -> None: + """""" + + position = np.atleast_1d(np.array(position).squeeze()) + + if position.size != self.joint_dofs(joint_name=joint_name): + raise ValueError( + f"Wrong position size ({position.size}) of " + f"{self.joint_dofs(joint_name=joint_name)}-DoFs joint '{joint_name}'." + ) + + idx = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, joint_name) + offset = self.model.jnt_qposadr[idx] + + sl = np.s_[offset : offset + self.joint_dofs(joint_name=joint_name)] + self.data.qpos[sl] = position + + def set_joint_positions( + self, joint_names: list[str], positions: npt.NDArray | list[npt.NDArray] + ) -> None: + """""" + + mask = self.mask_qpos(joint_names=tuple(joint_names)) + self.data.qpos[mask] = positions + + # ================== + # Methods for bodies + # ================== + + def number_of_bodies(self) -> int: + """""" + + return self.model.nbody + + def body_names(self) -> list[str]: + """""" + + return [ + mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_BODY, idx) + for idx in range(self.number_of_bodies()) + ] + + def body_position(self, body_name: str) -> npt.NDArray: + """""" + + if body_name not in self.body_names(): + raise ValueError(f"Body '{body_name}' not found") + + return self.data.body(body_name).xpos + + def body_orientation(self, body_name: str, dcm: bool = False) -> npt.NDArray: + """""" + + if body_name not in self.body_names(): + raise ValueError(f"Body '{body_name}' not found") + + return ( + self.data.body(body_name).xmat if dcm else self.data.body(body_name).xquat + ) + + # ====================== + # Methods for geometries + # ====================== + + def number_of_geometries(self) -> int: + """""" + + return self.model.ngeom + + def geometry_names(self) -> list[str]: + """""" + + return [ + mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_GEOM, idx) + for idx in range(self.number_of_geometries()) + ] + + def geometry_position(self, geometry_name: str) -> npt.NDArray: + """""" + + if geometry_name not in self.geometry_names(): + raise ValueError(f"Geometry '{geometry_name}' not found") + + return self.data.geom(geometry_name).xpos + + def geometry_orientation( + self, geometry_name: str, dcm: bool = False + ) -> npt.NDArray: + """""" + + if geometry_name not in self.geometry_names(): + raise ValueError(f"Geometry '{geometry_name}' not found") + + R = np.reshape(self.data.geom(geometry_name).xmat, newshape=(3, 3)) + + if dcm: + return R + + q_xyzw = Rotation.from_matrix(R).as_quat(canonical=True) + return q_xyzw[[3, 0, 1, 2]] + + # =============== + # Private methods + # =============== + + def _mask_qpos(self, joint_names: tuple[str, ...]) -> npt.NDArray: + """ + Create a mask to access the DoFs of the desired `joint_names` in the `qpos` array. + + Args: + joint_names: A tuple containing the names of the joints. + + Returns: + A 1D array containing the indices of the `qpos` array to access the DoFs of + the desired `joint_names`. + + Note: + This method takes a tuple of strings because we cache the output mask for + each combination of joint names. We need a hashable object for the cache. + """ + + # Get the indices of the joints in `joint_names`. + idxs = [ + mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, joint_name) + for joint_name in joint_names + ] + + # We first get the index of each joint in the qpos array, and for those that + # have multiple DoFs, we expand their mask by appending new elements. + # Finally, we flatten the list of arrays to a single array, that is the + # final qpos mask accessing all the DoFs of the desired `joint_names`. + return np.atleast_1d( + np.hstack( + [ + np.array( + [ + self.model.jnt_qposadr[idx] + i + for i in range(self.joint_dofs(joint_name=joint_name)) + ] + ) + for idx, joint_name in zip(idxs, joint_names) + ] + ).squeeze() + ) From 8710f262a3018d47be090174c003fc97304fafde Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 9 Feb 2024 08:09:45 +0100 Subject: [PATCH 4/8] Add jaxsim.mujoco.visualizer exposing the passive viewer --- src/jaxsim/mujoco/visualizer.py | 62 +++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 src/jaxsim/mujoco/visualizer.py diff --git a/src/jaxsim/mujoco/visualizer.py b/src/jaxsim/mujoco/visualizer.py new file mode 100644 index 000000000..eb897ecce --- /dev/null +++ b/src/jaxsim/mujoco/visualizer.py @@ -0,0 +1,62 @@ +import contextlib +import pathlib +from typing import ContextManager + +import mujoco as mj +import mujoco.viewer + + +class MujocoVisualizer: + """""" + + def __init__( + self, model: mj.MjModel | None = None, data: mj.MjData | None = None + ) -> None: + """""" + + self.data = data + self.model = model + + def sync( + self, + viewer: mujoco.viewer.Handle, + model: mj.MjModel | None = None, + data: mj.MjData | None = None, + ) -> None: + """""" + + data = data if data is not None else self.data + model = model if model is not None else self.model + + mj.mj_forward(model, data) + viewer.sync() + + def open_viewer( + self, model: mj.MjModel | None = None, data: mj.MjData | None = None + ) -> mj.viewer.Handle: + """""" + + data = data if data is not None else self.data + model = model if model is not None else self.model + + handle = mj.viewer.launch_passive( + model, data, show_left_ui=False, show_right_ui=False + ) + + return handle + + @contextlib.contextmanager + def open( + self, + model: mj.MjModel | None = None, + data: mj.MjData | None = None, + close_on_exit: bool = True, + ) -> ContextManager[mujoco.viewer.Handle]: + """""" + + handle = self.open_viewer(model=model, data=data) + + try: + yield handle + finally: + handle.close() if close_on_exit else None From c1239539067d81cc75afb94b325c32504609b181 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 9 Feb 2024 08:07:46 +0100 Subject: [PATCH 5/8] Add jaxsim.mujoco cmdline --- src/jaxsim/mujoco/__main__.py | 191 ++++++++++++++++++++++++++++++++++ 1 file changed, 191 insertions(+) create mode 100644 src/jaxsim/mujoco/__main__.py diff --git a/src/jaxsim/mujoco/__main__.py b/src/jaxsim/mujoco/__main__.py new file mode 100644 index 000000000..0d6bb3008 --- /dev/null +++ b/src/jaxsim/mujoco/__main__.py @@ -0,0 +1,191 @@ +import argparse +import pathlib +import time + +import numpy as np + +from . import MujocoModelHelper, MujocoVisualizer, SdfToMjcf, UrdfToMjcf + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + prog="jaxsim.mujoco", + description="Process URDF and SDF files for Mujoco usage.", + ) + + parser.add_argument( + "-d", + "--description", + required=True, + metavar="INPUT_FILE", + type=pathlib.Path, + help="Path to the URDF or SDF file.", + ) + + parser.add_argument( + "-m", + "--model-name", + metavar="NAME", + type=str, + default=None, + help="The target model of a SDF description if multiple models exists.", + ) + + parser.add_argument( + "-e", + "--export", + metavar="MJCF_FILE", + type=pathlib.Path, + default=None, + help="Path to the exported MJCF file.", + ) + + parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + help="Override the output MJCF file if it already exists (default: %(default)s).", + ) + + parser.add_argument( + "-p", + "--print", + action="store_true", + default=False, + help="Print in the stdout the exported MJCF string (default: %(default)s).", + ) + + parser.add_argument( + "-v", + "--visualize", + action="store_true", + default=False, + help="Visualize the description in the Mujoco viewer (default: %(default)s).", + ) + + parser.add_argument( + "-b", + "--base-position", + metavar=("x", "y", "z"), + nargs=3, + type=float, + default=None, + help="Override the base position (supports only floating-base models).", + ) + + parser.add_argument( + "-q", + "--base-quaternion", + metavar=("w", "x", "y", "z"), + nargs=4, + type=float, + default=None, + help="Override the base quaternion (supports only floating-base models).", + ) + + args = parser.parse_args() + + # ================== + # Validate arguments + # ================== + + # Expand the path of the URDF/SDF file if not absolute. + if args.description is not None: + args.description = ( + ( + args.description + if args.description.is_absolute() + else pathlib.Path.cwd() / args.description + ) + .expanduser() + .absolute() + ) + + if not pathlib.Path(args.description).is_file(): + msg = f"The URDF/SDF file '{args.description}' does not exist." + parser.error(msg) + exit(1) + + # Expand the path of the output MJCF file if not absolute. + if args.export is not None: + args.export = ( + ( + args.export + if args.export.is_absolute() + else pathlib.Path.cwd() / args.export + ) + .expanduser() + .absolute() + ) + + if pathlib.Path(args.export).is_file() and not args.force: + msg = "The output file '{}' already exists, use '--force' to override." + parser.error(msg.format(args.export)) + exit(1) + + # ================================================ + # Load the URDF/SDF file and produce a MJCF string + # ================================================ + + match args.description.suffix.lower()[1:]: + + case "urdf": + mjcf_string, assets = UrdfToMjcf().convert(urdf=args.description) + + case "sdf": + mjcf_string, assets = SdfToMjcf().convert( + sdf=args.description, model_name=args.model_name + ) + + case _: + msg = f"The file extension '{args.description.suffix}' is not supported." + parser.error(msg) + exit(1) + + if args.print: + print(mjcf_string, flush=True) + + # ======================================== + # Write the MJCF string to the output file + # ======================================== + + if args.export is not None: + with open(args.export, "w+") as file: + file.write(mjcf_string) + + # ======================================= + # Visualize the MJCF in the Mujoco viewer + # ======================================= + + if args.visualize: + + mj_model_helper = MujocoModelHelper.build_from_xml( + mjcf_description=mjcf_string, assets=assets + ) + + viz = MujocoVisualizer(model=mj_model_helper.model, data=mj_model_helper.data) + + with viz.open() as viewer: + + with viewer.lock(): + if args.base_position is not None: + mj_model_helper.set_base_position( + position=np.array(args.base_position) + ) + + if args.base_quaternion is not None: + mj_model_helper.set_base_orientation( + orientation=np.array(args.base_quaternion) + ) + + viz.sync(viewer=viewer) + + while viewer.is_running(): + time.sleep(0.500) + + # ============================= + # Exit the program with success + # ============================= + + exit(0) From b034884cfe08cbd2cc89c6c1153dd9be01166c02 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 9 Feb 2024 17:26:51 +0100 Subject: [PATCH 6/8] Add MujocoVideoRecorder class --- src/jaxsim/mujoco/visualizer.py | 90 +++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/src/jaxsim/mujoco/visualizer.py b/src/jaxsim/mujoco/visualizer.py index eb897ecce..dac8d1a75 100644 --- a/src/jaxsim/mujoco/visualizer.py +++ b/src/jaxsim/mujoco/visualizer.py @@ -2,8 +2,98 @@ import pathlib from typing import ContextManager +import mediapy as media import mujoco as mj import mujoco.viewer +import numpy.typing as npt + + +class MujocoVideoRecorder: + """""" + + def __init__( + self, + model: mj.MjModel, + data: mj.MjData, + fps: int = 30, + width: int | None = None, + height: int | None = None, + **kwargs, + ) -> None: + """""" + + width = width if width is not None else model.vis.global_.offwidth + height = height if height is not None else model.vis.global_.offheight + + if model.vis.global_.offwidth != width: + model.vis.global_.offwidth = width + + if model.vis.global_.offheight != height: + model.vis.global_.offheight = height + + self.fps = fps + self.frames: list[npt.NDArray] = [] + self.data: mujoco.MjData | None = None + self.model: mujoco.MjModel | None = None + self.reset(model=model, data=data) + + self.renderer = mujoco.Renderer( + model=self.model, + **(dict(width=width, height=height) | kwargs), + ) + + def reset( + self, model: mj.MjModel | None = None, data: mj.MjData | None = None + ) -> None: + """""" + + self.frames = [] + + self.data = data if data is not None else self.data + self.model = model if model is not None else self.model + + def render_frame(self, camera_name: str | None = None) -> None: + """""" + + mujoco.mj_forward(self.model, self.data) + self.renderer.update_scene(data=self.data) # TODO camera name + + self.frames.append(self.renderer.render()) + + def write_video(self, path: pathlib.Path, exist_ok: bool = False) -> None: + """""" + + if path.is_dir(): + raise IsADirectoryError(f"The path '{path}' is a directory.") + + if not exist_ok and path.is_file(): + raise FileExistsError(f"The file '{path}' already exists.") + + media.write_video(path=path, images=self.frames, fps=self.fps) + + @staticmethod + def compute_down_sampling(original_fps: int, target_min_fps: int) -> int: + """ + Return the integer down-sampling factor to reach at least the target fps. + + Args: + original_fps: The original fps. + target_min_fps: The target minimum fps. + + Returns: + The down-sampling factor. + """ + + down_sampling = 1 + down_sampling_final = down_sampling + + while original_fps / (down_sampling + 1) >= target_min_fps: + down_sampling = down_sampling + 1 + + if int(original_fps / down_sampling) == original_fps / down_sampling: + down_sampling_final = down_sampling + + return down_sampling_final class MujocoVisualizer: From 5561c7ca0b4109daad5107fefce5c5a6f2a328f3 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 12 Feb 2024 10:51:21 +0100 Subject: [PATCH 7/8] Add visualization dependencies --- environment.yml | 5 +++++ setup.cfg | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/environment.yml b/environment.yml index a5f02f31d..61ee7d40d 100644 --- a/environment.yml +++ b/environment.yml @@ -12,13 +12,18 @@ dependencies: - pptree - rod # Optional dependencies from setup.cfg + # [style] - black - isort + # [testing] - idyntree - pytest - pytest-forked - pytest-icdiff - robot_descriptions + # [viz] + - mediapy + - mujoco >= 3.0.0 # System dependencies to run the tests - gz-sim7 # Other packages diff --git a/setup.cfg b/setup.cfg index c4e5c1b91..fa3c714f8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -73,6 +73,10 @@ testing = pytest-forked pytest-icdiff robot-descriptions +viz = + mediapy + mujoco >= 3.0.0 all = %(style)s %(testing)s + %(viz)s From 4dcd4a760ac2d9afe7d6a74fea385f2447ebb07f Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Mon, 12 Feb 2024 11:00:11 +0100 Subject: [PATCH 8/8] Address review Co-authored-by: Filippo Luca Ferretti --- src/jaxsim/mujoco/__main__.py | 11 ++++++----- src/jaxsim/mujoco/loaders.py | 17 ++++++----------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/jaxsim/mujoco/__main__.py b/src/jaxsim/mujoco/__main__.py index 0d6bb3008..b072b9148 100644 --- a/src/jaxsim/mujoco/__main__.py +++ b/src/jaxsim/mujoco/__main__.py @@ -1,5 +1,6 @@ import argparse import pathlib +import sys import time import numpy as np @@ -105,7 +106,7 @@ if not pathlib.Path(args.description).is_file(): msg = f"The URDF/SDF file '{args.description}' does not exist." parser.error(msg) - exit(1) + sys.exit(1) # Expand the path of the output MJCF file if not absolute. if args.export is not None: @@ -122,7 +123,7 @@ if pathlib.Path(args.export).is_file() and not args.force: msg = "The output file '{}' already exists, use '--force' to override." parser.error(msg.format(args.export)) - exit(1) + sys.exit(1) # ================================================ # Load the URDF/SDF file and produce a MJCF string @@ -141,7 +142,7 @@ case _: msg = f"The file extension '{args.description.suffix}' is not supported." parser.error(msg) - exit(1) + sys.exit(1) if args.print: print(mjcf_string, flush=True) @@ -151,7 +152,7 @@ # ======================================== if args.export is not None: - with open(args.export, "w+") as file: + with open(args.export, "w+", encoding="utf-8") as file: file.write(mjcf_string) # ======================================= @@ -188,4 +189,4 @@ # Exit the program with success # ============================= - exit(0) + sys.exit(0) diff --git a/src/jaxsim/mujoco/loaders.py b/src/jaxsim/mujoco/loaders.py index ba0b4a55d..1e612a9e7 100644 --- a/src/jaxsim/mujoco/loaders.py +++ b/src/jaxsim/mujoco/loaders.py @@ -155,7 +155,7 @@ def convert( joints_dict = {j.name: j for j in rod_model.joints()} # Convert all the joints not considered to fixed joints. - for joint_name in set([j.name for j in rod_model.joints()]) - considered_joints: + for joint_name in set(j.name for j in rod_model.joints()) - considered_joints: joints_dict[joint_name].type = "fixed" # Convert the ROD model to URDF. @@ -237,10 +237,8 @@ def convert( # Get the joint names. mj_joint_names = set( - [ - mj.mj_id2name(mj_model, mj.mjtObj.mjOBJ_JOINT, idx) - for idx in range(mj_model.njnt) - ] + mj.mj_id2name(mj_model, mj.mjtObj.mjOBJ_JOINT, idx) + for idx in range(mj_model.njnt) ) # Check that the Mujoco model only has the considered joints. @@ -267,8 +265,7 @@ def convert( root: ET._Element = tree.getroot() # Find the element (might be the root itself). - mujoco_elements = [e for e in root.iter("mujoco")] - mujoco_element: ET._Element = mujoco_elements[0] + mujoco_element: ET._Element = list(root.iter("mujoco"))[0] # -------------- # Add the motors @@ -403,10 +400,8 @@ def convert( # elements inside . for wb in mujoco_element.findall(".//worldbody"): if all( - [ - wb.find(f".//joint[@name='{j}']") is not None - for j in considered_joints - ] + wb.find(f".//joint[@name='{j}']") is not None + for j in considered_joints ): worldbody_element = wb break