From fb3b453a1d25281888e38cd33b11ab6e2f1bdcf8 Mon Sep 17 00:00:00 2001 From: Jonas Dittrich <58814480+Kakadus@users.noreply.github.com> Date: Mon, 11 Mar 2024 10:13:43 +0000 Subject: [PATCH] add compatibility with ray and tensorflow --- .../py/ns3ai_gym_env/envs/ns3_environment.py | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/model/gym-interface/py/ns3ai_gym_env/envs/ns3_environment.py b/model/gym-interface/py/ns3ai_gym_env/envs/ns3_environment.py index cbbbe39..6ff921a 100644 --- a/model/gym-interface/py/ns3ai_gym_env/envs/ns3_environment.py +++ b/model/gym-interface/py/ns3ai_gym_env/envs/ns3_environment.py @@ -25,13 +25,13 @@ def _create_space(self, spaceDesc): mtype = boxSpacePb.dtype if mtype == pb.INT: - mtype = np.int + mtype = int elif mtype == pb.UINT: - mtype = np.uint + raise NotImplementedError("uint is not supported by all rl frameworks. Use int instead!") elif mtype == pb.DOUBLE: - mtype = np.float + mtype = np.float64 else: - mtype = np.float + mtype = np.float32 space = spaces.Box(low=low, high=high, shape=shape, dtype=mtype) @@ -203,8 +203,7 @@ def _pack_data(self, actions, spaceDesc): boxContainerPb.intData.extend(actions) elif spaceDesc.dtype in ['uint', 'uint8', 'uint16', 'uint32', 'uint64']: - boxContainerPb.dtype = pb.UINT - boxContainerPb.uintData.extend(actions) + raise NotImplementedError("uint is not supported by all rl frameworks. Use int instead!") elif spaceDesc.dtype in ['float', 'float32', 'float64']: boxContainerPb.dtype = pb.FLOAT @@ -274,6 +273,8 @@ def get_state(self): def __init__(self, targetName, ns3Path, ns3Settings=None, shmSize=4096): if self._created: raise Exception('Error: Ns3Env is singleton') + self.targetName = targetName + self.shmSize = shmSize self._created = True self.exp = Experiment(targetName, ns3Path, py_binding, shmSize=shmSize) self.ns3Settings = ns3Settings @@ -336,3 +337,16 @@ def close(self): self.exp.kill() # destroy the message interface and its shared memory segment del self.exp + + def __getstate__(self): + return { + "targetName": self.targetName, + "ns3Path": ".", + "ns3Settings": self.ns3Settings, + "shmSize": self.shmSize, + } + + def __setstate__(self, state): + if hasattr(self, "exp"): + self.close() + self.__init__(**state)