Commit
add notes for variable length sequences
season0528 committed Mar 14, 2024
1 parent d28e1b0 commit a78a9eb
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions mamba_ssm/modules/mamba_simple.py
@@ -119,6 +119,7 @@ def __init__(
def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
"""
hidden_states: (B, L, D)
+ cu_seqlens: one-dimensional tensor of cumulative sequence lengths, as in the flash-attn varlen API; only used when variable-length sequences are packed into a single sequence, i.e. batch_size B=1
Returns: same shape as hidden_states
"""
batch, seqlen, dim = hidden_states.shape
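
As a note on the new argument, here is a minimal sketch (not part of this commit) of how a caller might build cu_seqlens in the flash-attn varlen style: the cumulative boundaries of several sequences packed along the length dimension of a single batch element. The sequence lengths, d_model, and the mamba_block call are illustrative assumptions.

import torch
import torch.nn.functional as F

seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)      # hypothetical per-sequence lengths
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0), (1, 0))   # tensor([0, 5, 8, 15])

d_model = 16
# pack the three sequences into one (B=1, L, D) tensor along the length dimension
hidden_states = torch.randn(1, int(cu_seqlens[-1]), d_model)
# out = mamba_block(hidden_states, cu_seqlens=cu_seqlens)  # assuming an initialized Mamba module
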
@@ -157,7 +158,7 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
self.D.float(),
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
- cu_seqlens=cu_seqlens[0] if cu_seqlens is not None else None,
+ cu_seqlens=cu_seqlens,
)
else:
x, z = xz.chunk(2, dim=1)
@@ -166,12 +167,12 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
if cu_seqlens is not None:
padded_x = x
count = 0
- for idx in cu_seqlens[0][1:-1].tolist():
+ for idx in cu_seqlens[1:-1].tolist():
padded_idx = idx + count*(self.d_conv - 1)
padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], self.d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2)
count = count + 1
x = padded_x
- assert x.shape[2] == (self.d_conv - 1) * len(cu_seqlens[0][1:-1]) + z.shape[2]
+ # assert x.shape[2] == (self.d_conv - 1) * len(cu_seqlens[1:-1]) + z.shape[2]

# Compute short convolution
if conv_state is not None:
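
For reference, the padding step above (Step 1 for cu_seqlens) can be exercised on its own. This is a standalone sketch under assumed shapes, not code from the repository: (d_conv - 1) zero columns are inserted at every interior sequence boundary so the causal conv1d cannot mix tokens from adjacent packed sequences.

import torch

d_conv, d_inner = 4, 8
cu_seqlens = torch.tensor([0, 5, 8, 15])           # three packed sequences
x = torch.randn(1, d_inner, int(cu_seqlens[-1]))   # (B=1, d_inner, L), length on the last dim

padded_x, count = x, 0
for idx in cu_seqlens[1:-1].tolist():
    padded_idx = idx + count * (d_conv - 1)
    zeros = torch.zeros(1, d_inner, d_conv - 1, dtype=x.dtype, device=x.device)
    padded_x = torch.cat((padded_x[:, :, :padded_idx], zeros, padded_x[:, :, padded_idx:]), dim=2)
    count += 1

# two interior boundaries -> 2 * (d_conv - 1) extra zero columns
assert padded_x.shape[2] == x.shape[2] + (d_conv - 1) * len(cu_seqlens[1:-1])
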
@@ -192,13 +193,13 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
# (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences
if cu_seqlens is not None:
mask = []
- for seq_len in (cu_seqlens[0][1:] - cu_seqlens[0][:-1]).tolist():
+ for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist():
mask.extend([True] * seq_len)
mask.extend([False] * (self.d_conv - 1))
mask = mask[:-(self.d_conv - 1)]
- assert x.shape[2] == len(mask)
+ # assert x.shape[2] == len(mask)
x = x[:, :, mask]
- assert x.shape[2] == z.shape[2]
+ # assert x.shape[2] == z.shape[2]

# We're careful here about the layout, to avoid extra transposes.
# We want dt to have d as the slowest moving dimension
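
Similarly, the masking step (Optional Step 2 for cu_seqlens) can be checked in isolation. Another standalone sketch with the same assumed shapes, not code from the repository: a boolean mask keeps the original token positions and drops the (d_conv - 1) zero columns inserted between sequences, so x lines up with z again after the convolution.

import torch

d_conv = 4
cu_seqlens = torch.tensor([0, 5, 8, 15])

mask = []
for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist():
    mask.extend([True] * seq_len)
    mask.extend([False] * (d_conv - 1))
mask = mask[:-(d_conv - 1)]            # no padding was added after the last sequence

# padded length 21 -> original packed length 15 once the False positions are dropped
assert len(mask) == 15 + 2 * (d_conv - 1)
assert sum(mask) == 15
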
@@ -222,7 +223,7 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
return_last_state=ssm_state is not None,
- cu_seqlens=cu_seqlens[0] if cu_seqlens is not None else None,
+ cu_seqlens=cu_seqlens,
)
if ssm_state is not None:
y, last_state = y
