Variable input sequence length #236

Open

ruoyxue opened this issue Mar 12, 2024 · 11 comments
ruoyxue commented Mar 12, 2024

Dear authors,
This is amazing work!
I'm working with variable sequence lengths of video data. In one batch there can be several videos with different frame counts, and they are padded to the same length. With a Transformer I use attention masks to handle the variable input lengths, but I do not see a similar mask in the Mamba forward function. Is there any solution for dealing with variable lengths within a batch when using Mamba? Thanks!

ruoyxue commented Mar 13, 2024

What's more, I'm using bidirectional Mamba; is there any solution for processing data with variable lengths in a batch?

tridao (Collaborator) commented Mar 13, 2024

Variable length is not currently implemented but will be in the future. For now you can pad your sequences.

EricPaul03 commented

I want to know what steps I should take on my sequence to do the padding correctly. For example, if my sequence length is 128 and I want to pad to a length of 200, can I directly cut the output back to the first 128 positions before Mamba's skip connection? The outputs at the last 72 positions will not interfere with my results, right?

tridao (Collaborator) commented Mar 13, 2024

Yes, padding tokens should be on the right.
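
For illustration, a minimal sketch of this right-padding approach in PyTorch; the run_right_padded helper and the mamba_block argument are made-up names for the example, not part of the library:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # seqs: list of tensors of shape (L_i, d_model) with different lengths L_i
    def run_right_padded(seqs, mamba_block):
        lengths = torch.tensor([s.shape[0] for s in seqs])
        x = pad_sequence(seqs, batch_first=True)   # right-pads with zeros -> (B, L_max, d_model)
        y = mamba_block(x)                          # (B, L_max, d_model)
        # Only the first lengths[i] outputs of row i are meaningful; e.g. a sequence of
        # length 128 padded to 200 keeps y[i, :128] and ignores the last 72 positions.
        valid = torch.arange(x.shape[1])[None, :] < lengths[:, None]   # (B, L_max) bool mask
        return y, valid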

season0528 (Contributor) commented Mar 14, 2024

Hi @tridao @ruoyxue @EricPaul03

For training on variable-length sequences, right-padding to the maximum length with zeros works as an alternative solution. But the padded tokens can hurt hardware utilization, because compute is wasted on meaningless padded positions.

We have added support for variable-length sequences for the Mamba block. #244

Hope it helps!

ZJEast commented Mar 17, 2024

Any progress? I want to use variable-length sequences too, and to test the performance of Mamba on some reinforcement learning tasks.

ZJEast commented Mar 17, 2024

I think using "gather" might be an alternative solution to this problem, for example:

    from torch import Tensor

    def hidden_state(self, input_ids: Tensor, input_len: Tensor):
        # Run the right-padded batch through the Mamba backbone.
        hs: Tensor = self.backbone(input_ids, inference_params=None)   # (B, L, D)
        B, L, D = hs.shape
        # Index of the last valid (non-padded) token of each sequence.
        last = (input_len - 1).view(B, 1, 1)
        # Keep only the hidden state at that position; padded outputs are discarded.
        hs = hs.gather(1, last.expand(B, 1, D))   # (B, 1, D)
        return hs.view(B, D)

Here "input_len" is the length of the sequence.

EricPaul03 commented

Hello, thank you very much for your answer. I don't quite understand your code yet. Can you tell me where to put it and how it works? Thank you again!

ZJEast commented Mar 19, 2024

Because Mamba is an RNN-style architecture, the outputs at the padding positions are dirty; what you need to do is simply drop those dirty outputs and keep them out of your downstream computation. This can be done in many ways in PyTorch. The padding should be on the right.

It really works for me.
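
As a concrete illustration of dropping the dirty outputs, a minimal sketch in PyTorch; the mean-pooling choice is just an assumption for the example:

    import torch

    # y: (B, L_max, D) outputs from a right-padded batch; lengths: (B,) true sequence lengths
    def masked_mean_pool(y, lengths):
        B, L, D = y.shape
        mask = torch.arange(L, device=y.device)[None, :] < lengths[:, None]   # True at real tokens
        y = y * mask.unsqueeze(-1)                          # zero out outputs at padded positions
        return y.sum(dim=1) / lengths.unsqueeze(-1).to(y.dtype)   # average over real tokens only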

EricPaul03 commented

But for performance reasons, I would like to use bidirectional Mamba, which scans in two directions, left to right and right to left. Is it no longer possible to simply pad on the right side?

ZJEast commented Mar 19, 2024

[1, 2, 3] -> [1, 2, 3, 3, 2, 1]?
Maybe you can try this preprocessing.
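
A minimal sketch of that mirrored preprocessing for 1-D token-id tensors (only an illustration of the suggestion above, not a library feature):

    import torch

    def mirror(ids: torch.Tensor) -> torch.Tensor:
        # Append the reversed sequence: [1, 2, 3] -> [1, 2, 3, 3, 2, 1].
        return torch.cat([ids, ids.flip(0)], dim=0)

    mirror(torch.tensor([1, 2, 3]))   # tensor([1, 2, 3, 3, 2, 1])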
