Skip to content

Commit

Permalink
Merge pull request #1 from ARISE-Initiative/master
Browse files Browse the repository at this point in the history
fix scan modality (ARISE-Initiative#119)
  • Loading branch information
Dhanushvarma authored Dec 13, 2023
2 parents aa0da68 + b5ce962 commit 7df309c
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 1 deletion.
5 changes: 5 additions & 0 deletions robomimic/models/base_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,11 @@ def __init__(

# Get activation requested
activation = CONV_ACTIVATIONS[activation]

# Add layer kwargs
conv_kwargs["out_channels"] = out_channels
conv_kwargs["kernel_size"] = kernel_size
conv_kwargs["stride"] = stride

# Generate network
self.n_layers = len(out_channels)
Expand Down
3 changes: 2 additions & 1 deletion robomimic/models/obs_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,9 @@ def __init__(
conv_kwargs = dict()

# Generate backbone network
# N input channels is assumed to be the first dimension
self.backbone = BaseNets.Conv1dBase(
input_channel=1,
input_channel=self.input_shape[0],
activation=conv_activation,
**conv_kwargs,
)
Expand Down
24 changes: 24 additions & 0 deletions robomimic/utils/obs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,10 +944,34 @@ class ScanModality(Modality):

@classmethod
def _default_obs_processor(cls, obs):
# Channel swaps ([...,] L, C) --> ([...,] C, L)

# First, add extra dimension at 2nd to last index to treat this as a frame
shape = obs.shape
new_shape = [*shape[:-2], 1, *shape[-2:]]
obs = obs.reshape(new_shape)

# Convert shape
obs = batch_image_hwc_to_chw(obs)

# Remove extra dimension (it's the second from last dimension)
obs = obs.squeeze(-2)
return obs

@classmethod
def _default_obs_unprocessor(cls, obs):
# Channel swaps ([B,] C, L) --> ([B,] L, C)

# First, add extra dimension at 1st index to treat this as a frame
shape = obs.shape
new_shape = [*shape[:-2], 1, *shape[-2:]]
obs = obs.reshape(new_shape)

# Convert shape
obs = batch_image_chw_to_hwc(obs)

# Remove extra dimension (it's the second from last dimension)
obs = obs.squeeze(-2)
return obs


Expand Down

0 comments on commit 7df309c

Please sign in to comment.