The inference results with the trained checkpoint (RT1main) look very weird. #25

Open
ka2hyeon opened this issue Nov 10, 2023 · 2 comments

ka2hyeon commented Nov 10, 2023

After downloading the RT-1 dataset here, I ran inference on it with your checkpoint (RT1 main). However, the predicted actions differ substantially from the ground-truth actions in the dataset. When plotted, the predicted trajectories have roughly the same shape as the dataset trajectories, but show a scale shift and drift. Moreover, the training loss drops substantially when I train the checkpoint further on the dataset. It seems that either the checkpoint was trained on a different dataset from the one you shared, or I am running inference incorrectly.
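One way to make "scale shift and drift" measurable is to fit, per action dimension, a line mapping ground-truth values to predicted values, and to sum the per-step error over the episode. The sketch below assumes both action sequences have already been stacked into NumPy arrays of shape (T, D); `compare_trajectories` is a hypothetical diagnostic helper, not part of the RT-1 code.

```python
import numpy as np


def compare_trajectories(pred: np.ndarray, gt: np.ndarray) -> dict:
    """Per-dimension diagnostics for predicted vs. ground-truth action sequences.

    Hypothetical helper. pred and gt are arrays of shape (T, D), one row per step.
    """
    pred = np.asarray(pred, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    if pred.ndim != 2 or pred.shape != gt.shape:
        raise ValueError("expected two arrays of identical shape (T, D)")

    num_dims = pred.shape[1]
    scale = np.empty(num_dims)
    offset = np.empty(num_dims)
    for d in range(num_dims):
        # Least-squares line pred[:, d] ~ scale * gt[:, d] + offset.
        # np.polyfit returns coefficients from the highest power down: [slope, intercept].
        scale[d], offset[d] = np.polyfit(gt[:, d], pred[:, d], deg=1)

    error = pred - gt
    # Drift: accumulated error at the last step if the deltas are integrated over time.
    final_drift = np.cumsum(error, axis=0)[-1]
    mae = np.mean(np.abs(error), axis=0)
    return {"scale": scale, "offset": offset, "final_drift": final_drift, "mae": mae}
```

A fitted `scale` far from 1 or a non-zero `offset` would suggest a normalization or unit mismatch, while a large `final_drift` would indicate a small per-step bias that accumulates once the deltas are integrated into a trajectory.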

I appreciate your work on large-scale robot learning and would like to reproduce your results. Could you comment on the mismatch between the inference results and the ground truth in your dataset?

Here is pseudo-code of what I did:

# Data source loading
import tensorflow as tf
import tensorflow_datasets as tfds

builder = tfds.builder_from_directory(builder_dir=[RT1 data dir])
ds = builder.as_data_source(split=split, decoders=tfds.decode.SkipDecoding())
example = tf.train.Example.FromString(ds.data_source[episode_id])
features = example.features.feature

# tf.io.decode_image expects a single encoded image, so decode each entry of the bytes list separately
images = tf.stack([
    tf.io.decode_image(img)
    for img in features["steps/observations/images"].bytes_list.value
])
rotation_delta = tf.constant(features["steps/actions/rotation_delta"].float_list.value, dtype=tf.float32)
...
observations = {'images': images, ...}
actions_gt = {'actions': rotation_delta, ...}

# Inference
from tf_agents.trajectories.time_step import TimeStep
time_step = TimeStep(
        observation=observations,
        ...  # step_type, reward and discount must be set as well
    )
policy = tf.saved_model.load(pb_path_of_rt1main)
policy_state = policy.get_initial_state(batch_size=1)
action_step = policy.action(time_step, policy_state)  # PolicyStep with .action and .state

# !! action_step.action differs from actions_gt
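The pseudo-code above queries the policy once with a freshly initialized state. tf_agents policies return the next state in `action_step.state`, and RT-1 conditions on a short history of images, so the released policy is likely stateful; evaluating a whole episode would then mean querying it step by step and feeding the returned state back in. The sketch below shows only that loop. `rollout_policy` and `make_time_step` are hypothetical names: `make_time_step(obs, is_first)` stands in for whatever wraps one step's observation into a time step (in tf_agents, for example, `time_step.restart` for the first step and `time_step.transition` afterwards), with the batch dimension the policy expects.

```python
from typing import Any, Callable, List, Sequence


def rollout_policy(
    policy: Any,
    observations: Sequence[Any],
    make_time_step: Callable[[Any, bool], Any],
) -> List[Any]:
    """Query a stateful policy once per step, threading its state through the episode.

    Hypothetical helper: `policy` needs get_initial_state(batch_size=...) and
    action(time_step, state) returning an object with .action and .state.
    """
    state = policy.get_initial_state(batch_size=1)
    actions = []
    for t, obs in enumerate(observations):
        time_step = make_time_step(obs, t == 0)  # first step is marked separately
        action_step = policy.action(time_step, state)
        actions.append(action_step.action)
        # Carry the updated state forward instead of resetting it at every step.
        state = action_step.state
    return actions
```

The per-step predictions collected this way can then be stacked and compared against `actions_gt` for the same episode.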

KzZheng commented Jan 22, 2024

Sorry to bother you, but I wonder how you managed to load the data. I also downloaded it, but the file names referenced by ds.data_source, e.g. 'RT_1_paper_release-train.array_record-00000-of-01024', are not included in the dataset. Do you know how to fix this? Thanks!
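A quick way to see how far the local copy is from what the data source expects is to list which shard files are absent. The sketch below assumes shard names follow the pattern of the single example above (`<prefix>-<index:05d>-of-<count:05d>`, here with 1024 shards); `missing_shards` is a hypothetical helper, and it only reports the gap, it does not download anything.

```python
import os
from typing import List


def missing_shards(data_dir: str, prefix: str, num_shards: int) -> List[str]:
    """Return the expected shard file names that are absent from data_dir.

    Assumes names like f"{prefix}-{i:05d}-of-{num_shards:05d}",
    e.g. "RT_1_paper_release-train.array_record-00000-of-01024".
    """
    present = set(os.listdir(data_dir))
    expected = [f"{prefix}-{i:05d}-of-{num_shards:05d}" for i in range(num_shards)]
    return [name for name in expected if name not in present]
```

For example, `missing_shards(data_dir, "RT_1_paper_release-train.array_record", 1024)` returning all 1024 names would mean no array_record shards for that split are present in `data_dir` under the expected names.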


safsin commented Feb 16, 2024

Refer to the ipynb notebook for how to load the dataset.
