
Any example of online inference of S4 block? #158

Open
traidn opened this issue Dec 27, 2024 · 1 comment
traidn commented Dec 27, 2024

Are there any examples of how to run the S4 block in recurrent mode? I tried using the step function, but it gives errors. I'm attaching my script. What could be the problem?

import torch
from s4 import S4
from sashimi import ResidualBlock

def s4_block(dim):
    layer = S4(
        d_model=dim,
        d_state=16,
        bidirectional=False,
        dropout=0.0,
        transposed=True,
    )
    return ResidualBlock(
        d_model=dim,
        layer=layer,
        dropout=0.0,
    )

model = s4_block(16)
for module in model.modules():
    if hasattr(module, 'setup_step'): module.setup_step(mode="diagonal")
model.eval()

input_seg = torch.randn(1, 16, 100)

full_out, _ = model(input_seg)
print(full_out)

s4_state = model.default_state()
stream_res = []
for i in range(input_seg.shape[-1]):
    part_input = input_seg[:, :, i]
    print(part_input.shape)
    part_res, s4_state = model.step(part_input, s4_state)
    stream_res.append(part_res)

stream_res = torch.cat(stream_res, dim=2)
print(stream_res)
print(torch.allclose(full_out, stream_res))
traidn changed the title from "Any example of online inferense of S4 block?" to "Any example of online inference of S4 block?" on Dec 27, 2024
sendeniz commented Jan 4, 2025

Hi @traidn, I am working on something similar at the moment. Could you post your error? I tried running your code, and it runs despite this error at the beginning:

Diagonalization error: tensor(0.2134, grad_fn=<DistBackward0>)

Is this the error you are referring to? If so, you can find this note in the documentation of sashimi.py:

S4 recurrence mode. Using `diagonal` can speed up generation by 10-20%.
`linear` should be faster theoretically but is slow in practice since it
dispatches more operations (could benefit from fused operations).
Note that `diagonal` could potentially be unstable if the diagonalization is numerically unstable
(although we haven't encountered this case in practice), while `dense` should always be stable.

So calling module.setup_step(mode="linear") instead of mode="diagonal" resolves this issue, even though it is slower than diagonal. Please let me know if this works for you.
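As an aside, the parity check your script attempts (full forward pass vs. sample-by-sample stepping with carried state) can be illustrated without the S4 codebase at all, since S4 is at heart a linear state-space recurrence. Below is a minimal, dependency-free sketch with a hypothetical scalar SSM (parameters a, b, c, d are made up for illustration, not taken from the repo); the step function mirrors the role of S4's step, and the two modes should agree exactly:

```python
# Minimal scalar state-space model:
#   x[k+1] = a*x[k] + b*u[k]        (state update)
#   y[k]   = c*x[k] + d*u[k]        (output)
# Illustrative stand-in for an S4 layer; parameters are hypothetical.
a, b, c, d = 0.9, 0.5, 1.0, 0.2

def full_pass(u):
    """Process the whole sequence at once (analogous to the convolutional forward)."""
    x, ys = 0.0, []
    for uk in u:
        ys.append(c * x + d * uk)   # emit output for step k
        x = a * x + b * uk          # advance the state
    return ys

def step(uk, x):
    """Process one sample, returning output and updated state (analogous to S4's step)."""
    return c * x + d * uk, a * x + b * uk

u = [1.0, -0.5, 2.0, 0.0]
x = 0.0                             # default (zero) initial state
stream = []
for uk in u:
    y, x = step(uk, x)
    stream.append(y)

# Both modes produce identical outputs.
assert all(abs(p - q) < 1e-9 for p, q in zip(full_pass(u), stream))
```

If the two paths disagree in the real model, the usual suspects are a non-zero or stale initial state, dropout left active (no eval()), or an unstable diagonalization of the kind the sashimi.py docstring warns about.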

Best Deniz
