Commit b37712f: vq init
MarkFzp committed Oct 13, 2023 (parent 694c606)
Showing 10 changed files with 657 additions and 40 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -137,4 +137,6 @@ celerybeat.pid
dmypy.json

# Pyre type checker
.pyre/
.pyre/

ckpts/
12 changes: 2 additions & 10 deletions README.md
@@ -57,10 +57,7 @@ To set up a new terminal, run:
We use the ``sim_transfer_cube_scripted`` task in the examples below. Another option is ``sim_insertion_scripted``.
To generate 50 episodes of scripted data, run:

python3 record_sim_episodes.py \
--task_name sim_transfer_cube_scripted \
--dataset_dir <data save dir> \
--num_episodes 50
python3 record_sim_episodes.py --task_name sim_transfer_cube_scripted --dataset_dir <data save dir> --num_episodes 50

You can add the flag ``--onscreen_render`` to see real-time rendering.
To visualize the episode after it is collected, run
@@ -70,12 +67,7 @@ To visualize the episode after it is collected, run
To train ACT:

# Transfer Cube task
python3 imitate_episodes.py \
--task_name sim_transfer_cube_scripted \
--ckpt_dir <ckpt dir> \
--policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 \
--num_epochs 2000 --lr 1e-5 \
--seed 0
python3 imitate_episodes.py --task_name sim_transfer_cube_scripted --ckpt_dir <ckpt dir> --policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 --num_epochs 2000 --lr 1e-5 --seed 0


To evaluate the policy, run the same command but add ``--eval``. This loads the best validation checkpoint.
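For reference, the full evaluation invocation assembled from the training command above (unchanged except for the added flag) is:

python3 imitate_episodes.py --task_name sim_transfer_cube_scripted --ckpt_dir <ckpt dir> --policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 --num_epochs 2000 --lr 1e-5 --seed 0 --eval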
2 changes: 1 addition & 1 deletion constants.py
@@ -1,7 +1,7 @@
import pathlib

### Task parameters
DATA_DIR = '<put your data dir here>'
DATA_DIR = 'data'
SIM_TASK_CONFIGS = {
'sim_transfer_cube_scripted':{
'dataset_dir': DATA_DIR + '/sim_transfer_cube_scripted',
1 change: 1 addition & 0 deletions detr/main.py
@@ -63,6 +63,7 @@ def get_args_parser():
parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
parser.add_argument('--temporal_agg', action='store_true')
parser.add_argument('--vq', action='store_true')

return parser

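Note that only ``--vq`` is added to this parser, while ``build()`` in detr_vae.py (next file) also reads ``args.vq_class`` and ``args.vq_dim``. Those presumably reach the model through the config-override pattern of ``build_ACT_model_and_optimizer``, which copies the policy config onto the parsed args; a sketch of that assumed mechanism, not the verbatim source:

# Assumed override mechanism: every key of the policy-config dict is written
# onto args, so 'vq_class' and 'vq_dim' reach build(args) without CLI flags.
for k, v in args_override.items():
    setattr(args, k, v)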
73 changes: 55 additions & 18 deletions detr/models/detr_vae.py
@@ -5,6 +5,7 @@
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
from .backbone import build_backbone
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer

@@ -33,7 +34,7 @@ def get_position_angle_vec(position):

class DETRVAE(nn.Module):
""" This is the DETR module that performs object detection """
def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names):
def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names, vq, vq_class, vq_dim):
""" Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
@@ -48,6 +49,7 @@ def __init__(self, backbones, transformer, encoder, state_dim, num_queries, came
self.camera_names = camera_names
self.transformer = transformer
self.encoder = encoder
self.vq, self.vq_class, self.vq_dim = vq, vq_class, vq_dim
hidden_dim = transformer.d_model
self.action_head = nn.Linear(hidden_dim, state_dim)
self.is_pad_head = nn.Linear(hidden_dim, 1)
@@ -68,20 +70,24 @@ def __init__(self, backbones, transformer, encoder, state_dim, num_queries, came
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var

print('Use VQ: ', self.vq)
if self.vq:
self.latent_proj = nn.Linear(hidden_dim, self.vq_class * self.vq_dim)
else:
self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq

# decoder extra parameters
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
if self.vq:
self.latent_out_proj = nn.Linear(self.vq_class * self.vq_dim, hidden_dim)
else:
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent

def forward(self, qpos, image, env_state, actions=None, is_pad=None):
"""
qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width
env_state: None
actions: batch, seq, action_dim
"""

def encode(self, qpos, actions=None, is_pad=None, vq_sample=None):
# cvae encoder
is_training = actions is not None # train or val
bs, _ = qpos.shape
### Obtain latent z from action sequence
@@ -104,15 +110,43 @@ def forward(self, qpos, image, env_state, actions=None, is_pad=None):
encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
encoder_output = encoder_output[0] # take cls output only
latent_info = self.latent_proj(encoder_output)
mu = latent_info[:, :self.latent_dim]
logvar = latent_info[:, self.latent_dim:]
latent_sample = reparametrize(mu, logvar)
latent_input = self.latent_out_proj(latent_sample)

if self.vq:
logits = latent_info.reshape([*latent_info.shape[:-1], self.vq_class, self.vq_dim])
probs = torch.softmax(logits, dim=-1)
binaries = F.one_hot(torch.multinomial(probs.view(-1, self.vq_dim), 1).squeeze(-1), self.vq_dim).view(-1, self.vq_class, self.vq_dim).float()
binaries_flat = binaries.view(-1, self.vq_class * self.vq_dim)
probs_flat = probs.view(-1, self.vq_class * self.vq_dim)
straight_through = binaries_flat - probs_flat.detach() + probs_flat
latent_input = self.latent_out_proj(straight_through)
mu = logvar = None
else:
mu = latent_info[:, :self.latent_dim]
logvar = latent_info[:, self.latent_dim:]
latent_sample = reparametrize(mu, logvar)
latent_input = self.latent_out_proj(latent_sample)

else:
mu = logvar = None
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
latent_input = self.latent_out_proj(latent_sample)
mu = logvar = binaries = probs = None
if self.vq:
latent_input = self.latent_out_proj(vq_sample.view(-1, self.vq_class * self.vq_dim))
else:
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
latent_input = self.latent_out_proj(latent_sample)


return latent_input, probs, binaries, mu, logvar

def forward(self, qpos, image, env_state, actions=None, is_pad=None, vq_sample=None):
"""
qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width
env_state: None
actions: batch, seq, action_dim
"""
latent_input, probs, binaries, mu, logvar = self.encode(qpos, actions, is_pad, vq_sample)

# cvae decoder
if self.backbones is not None:
# Image observation features and position embeddings
all_cam_features = []
@@ -136,7 +170,7 @@ def forward(self, qpos, image, env_state, actions=None, is_pad=None):
hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
a_hat = self.action_head(hs)
is_pad_hat = self.is_pad_head(hs)
return a_hat, is_pad_hat, [mu, logvar]
return a_hat, is_pad_hat, [mu, logvar], probs, binaries



@@ -247,6 +281,9 @@ def build(args):
state_dim=state_dim,
num_queries=args.num_queries,
camera_names=args.camera_names,
vq=args.vq,
vq_class=args.vq_class,
vq_dim=args.vq_dim,
)

n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
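The VQ path in the new ``encode`` replaces the Gaussian reparametrization with a discrete bottleneck: the CLS output is projected to ``vq_class`` independent categorical distributions over ``vq_dim`` categories, one category is sampled per distribution, and the hard one-hot codes reach the decoder with gradients routed through the soft probabilities (a straight-through estimator). A minimal standalone sketch of that mechanism, using an illustrative batch size and the vq_class = vq_dim = 10 defaults from imitate_episodes.py:

import torch
import torch.nn.functional as F

vq_class, vq_dim = 10, 10
latent_info = torch.randn(4, vq_class * vq_dim, requires_grad=True)  # stand-in for the CLS projection

logits = latent_info.reshape(-1, vq_class, vq_dim)
probs = torch.softmax(logits, dim=-1)  # one categorical distribution per VQ class
samples = torch.multinomial(probs.view(-1, vq_dim), 1).squeeze(-1)
binaries = F.one_hot(samples, vq_dim).view(-1, vq_class, vq_dim).float()

# Straight-through: the forward value equals the hard one-hot codes, while the
# backward pass sees only the soft probabilities.
binaries_flat = binaries.view(-1, vq_class * vq_dim)
probs_flat = probs.view(-1, vq_class * vq_dim)
straight_through = binaries_flat - probs_flat.detach() + probs_flat

(straight_through * torch.randn_like(straight_through)).sum().backward()
assert latent_info.grad is not None  # gradients flow back into the encoder projection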
73 changes: 73 additions & 0 deletions detr/models/latent_model.py
@@ -0,0 +1,73 @@
import torch.nn as nn
from torch.nn import functional as F
import torch

DROPOUT_RATE = 0.1

# a causal transformer block
class Causal_Transformer_Block(nn.Module):
def __init__(self, seq_len, latent_dim, num_head) -> None:
super().__init__()
self.num_head = num_head
self.latent_dim = latent_dim
self.ln_1 = nn.LayerNorm(latent_dim)
self.attn = nn.MultiheadAttention(latent_dim, num_head, dropout=DROPOUT_RATE, batch_first=True)
self.ln_2 = nn.LayerNorm(latent_dim)
self.mlp = nn.Sequential(
nn.Linear(latent_dim, 4 * latent_dim),
nn.GELU(),
nn.Linear(4 * latent_dim, latent_dim),
nn.Dropout(DROPOUT_RATE),
)

# self.register_buffer("attn_mask", torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool())

def forward(self, x):
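# Build the causal mask on the fly for the current sequence length; True entries
# block attention to future positions (the registered-buffer variant above is disabled).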
attn_mask = torch.triu(torch.ones(x.shape[1], x.shape[1], device=x.device, dtype=torch.bool), diagonal=1)
x = self.ln_1(x)
x = x + self.attn(x, x, x, attn_mask=attn_mask)[0]
x = self.ln_2(x)
x = x + self.mlp(x)

return x

# use self-attention instead of RNN to model the latent space sequence
class Latent_Model_Transformer(nn.Module):
def __init__(self, input_dim, output_dim, seq_len, latent_dim=256, num_head=8, num_layer=3) -> None:
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.seq_len = seq_len
self.latent_dim = latent_dim
self.num_head = num_head
self.num_layer = num_layer
self.input_layer = nn.Linear(input_dim, latent_dim)
self.weight_pos_embed = nn.Embedding(seq_len, latent_dim)
self.attention_blocks = nn.Sequential(
nn.Dropout(DROPOUT_RATE),
*[Causal_Transformer_Block(seq_len, latent_dim, num_head) for _ in range(num_layer)],
nn.LayerNorm(latent_dim)
)
self.output_layer = nn.Linear(latent_dim, output_dim)

def forward(self, x):
x = self.input_layer(x)
x = x + self.weight_pos_embed(torch.arange(x.shape[1], device=x.device))
x = self.attention_blocks(x)
logits = self.output_layer(x)

return logits

@torch.no_grad()
def generate(self, n, temperature=0.1, x=None):
if x is None:
x = torch.zeros((n, 1, self.input_dim), device=self.weight_pos_embed.weight.device)
for i in range(self.seq_len):
logits = self.forward(x)[:, -1]
probs = torch.softmax(logits / temperature, dim=-1)
samples = torch.multinomial(probs, num_samples=1)[..., 0]
samples_one_hot = F.one_hot(samples.long(), num_classes=self.output_dim).float()
x = torch.cat([x, samples_one_hot[:, None, :]], dim=1)

return x[:, 1:, :]
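``generate`` starts from a zero token and autoregressively samples one one-hot token per step, so a draw from the prior has shape (n, seq_len, output_dim); at the low temperature used at eval time (0.01) sampling is nearly greedy. A small shape-check sketch under the vq_class = vq_dim = 10 defaults assumed elsewhere in this commit:

import torch
from detr.models.latent_model import Latent_Model_Transformer

vq_class, vq_dim = 10, 10  # values hardcoded in imitate_episodes.py
latent_model = Latent_Model_Transformer(input_dim=vq_dim, output_dim=vq_dim, seq_len=vq_class)

vq_sample = latent_model.generate(n=1, temperature=0.01, x=None)
assert vq_sample.shape == (1, vq_class, vq_dim)  # one token per VQ class
assert bool(torch.all(vq_sample.sum(dim=-1) == 1))  # each token is one-hot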

26 changes: 24 additions & 2 deletions imitate_episodes.py
@@ -16,6 +16,8 @@
from policy import ACTPolicy, CNNMLPPolicy
from visualize_episodes import save_videos

from detr.models.latent_model import Latent_Model_Transformer

from sim_env import BOX_POSE

import IPython
@@ -54,6 +56,8 @@ def main(args):
enc_layers = 4
dec_layers = 7
nheads = 8
vq_class = 10
vq_dim = 10
policy_config = {'lr': args['lr'],
'num_queries': args['chunk_size'],
'kl_weight': args['kl_weight'],
@@ -65,6 +69,9 @@
'dec_layers': dec_layers,
'nheads': nheads,
'camera_names': camera_names,
'vq': args['vq'],
'vq_class': vq_class,
'vq_dim': vq_dim,
}
elif policy_class == 'CNNMLP':
policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1,
@@ -161,6 +168,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
task_name = config['task_name']
temporal_agg = config['temporal_agg']
onscreen_cam = 'angle'
vq = config['policy_config']['vq']

# load policy and stats
ckpt_path = os.path.join(ckpt_dir, ckpt_name)
@@ -169,7 +177,16 @@
print(loading_status)
policy.cuda()
policy.eval()
print(f'Loaded: {ckpt_path}')
if vq:
vq_dim = config['policy_config']['vq_dim']
vq_class = config['policy_config']['vq_class']
latent_model = Latent_Model_Transformer(vq_dim, vq_dim, vq_class)
latent_model_ckpt_path = os.path.join(ckpt_dir, 'latent_model_best.ckpt')
latent_model.load_state_dict(torch.load(latent_model_ckpt_path))
latent_model.cuda()
print(f'Loaded policy from: {ckpt_path}, latent model from: {latent_model_ckpt_path}')
else:
print(f'Loaded: {ckpt_path}')
stats_path = os.path.join(ckpt_dir, f'dataset_stats.pkl')
with open(stats_path, 'rb') as f:
stats = pickle.load(f)
@@ -246,7 +263,11 @@
### query policy
if config['policy_class'] == "ACT":
if t % query_frequency == 0:
all_actions = policy(qpos, curr_image)
if vq:
vq_sample = latent_model.generate(1, temperature=0.01, x=None)
all_actions = policy(qpos, curr_image, vq_sample=vq_sample)
else:
all_actions = policy(qpos, curr_image)
if temporal_agg:
all_time_actions[[t], t:t+num_queries] = all_actions
actions_for_curr_step = all_time_actions[:, t]
@@ -431,5 +452,6 @@ def plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed):
parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)
parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)
parser.add_argument('--temporal_agg', action='store_true')
parser.add_argument('--vq', action='store_true')

main(vars(parser.parse_args()))
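With ``--vq`` now wired into both parsers, a VQ training run would presumably mirror the README's ACT command with the new flag appended (hypothetical invocation; ``vq_class`` and ``vq_dim`` stay at their hardcoded value of 10):

python3 imitate_episodes.py --task_name sim_transfer_cube_scripted --ckpt_dir <ckpt dir> --policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 --num_epochs 2000 --lr 1e-5 --seed 0 --vq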
26 changes: 22 additions & 4 deletions policy.py
@@ -1,6 +1,7 @@
import torch.nn as nn
from torch.nn import functional as F
import torchvision.transforms as transforms
import torch

from detr.main import build_ACT_model_and_optimizer, build_CNNMLP_model_and_optimizer
import IPython
@@ -13,9 +14,10 @@ def __init__(self, args_override):
self.model = model # CVAE decoder
self.optimizer = optimizer
self.kl_weight = args_override['kl_weight']
self.vq = args_override['vq']
print(f'KL Weight {self.kl_weight}')

def __call__(self, qpos, image, actions=None, is_pad=None):
def __call__(self, qpos, image, actions=None, is_pad=None, vq_sample=None):
env_state = None
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
@@ -24,22 +26,38 @@ def __call__(self, qpos, image, actions=None, is_pad=None):
actions = actions[:, :self.model.num_queries]
is_pad = is_pad[:, :self.model.num_queries]

a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
loss_dict = dict()
a_hat, is_pad_hat, (mu, logvar), probs, binaries = self.model(qpos, image, env_state, actions, is_pad, vq_sample)
if self.vq:
total_kld = [torch.tensor(0.0)]
loss_dict['vq_loss'] = F.l1_loss(binaries, probs, reduction='mean')
else:
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
all_l1 = F.l1_loss(actions, a_hat, reduction='none')
l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
loss_dict['l1'] = l1
loss_dict['kl'] = total_kld[0]
loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
return loss_dict
else: # inference time
a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
a_hat, _, (_, _), _, _ = self.model(qpos, image, env_state, vq_sample=vq_sample) # no action, sample from prior
return a_hat

def configure_optimizers(self):
return self.optimizer

@torch.no_grad()
def vq_encode(self, qpos, actions, is_pad):
actions = actions[:, :self.model.num_queries]
is_pad = is_pad[:, :self.model.num_queries]

_, _, binaries, _, _ = self.model.encode(qpos, actions, is_pad)

return binaries




class CNNMLPPolicy(nn.Module):
def __init__(self, args_override):
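``eval_bc`` loads ``latent_model_best.ckpt``, but none of the files in this commit train that prior; ``vq_encode`` is evidently the hook for it. A rough sketch of how such a loop could use it (entirely hypothetical: the function name, dataloader contract, and loss are assumptions, not part of this commit):

import torch
from torch.nn import functional as F

def train_latent_prior(policy, latent_model, dataloader, num_steps, lr=1e-4):
    # Assumes policy (ACTPolicy with vq=True) and latent_model are already on GPU.
    optimizer = torch.optim.AdamW(latent_model.parameters(), lr=lr)
    for step, (image_data, qpos_data, action_data, is_pad) in enumerate(dataloader):
        if step >= num_steps:
            break
        # Discrete codes from the frozen CVAE encoder: (batch, vq_class, vq_dim).
        binaries = policy.vq_encode(qpos_data.cuda(), action_data.cuda(), is_pad.cuda())
        # Teacher forcing: predict token i from tokens < i, starting from the same
        # zero token that generate() uses.
        inputs = torch.cat([torch.zeros_like(binaries[:, :1]), binaries[:, :-1]], dim=1)
        logits = latent_model(inputs)
        loss = F.cross_entropy(logits.flatten(0, 1), binaries.argmax(dim=-1).flatten())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()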