Are there any examples of how to run inference on an S4 block in recurrent mode? I tried using the `step` function, but it gives errors. I'm attaching my script. What could be the problem?
import torch
from s4 import S4
from sashimi import ResidualBlock

def s4_block(dim):
    layer = S4(
        d_model=dim,
        d_state=16,
        bidirectional=False,
        dropout=0.0,
        transposed=True,
    )
    return ResidualBlock(
        d_model=dim,
        layer=layer,
        dropout=0.0,
    )

model = s4_block(16)
for module in model.modules():
    if hasattr(module, 'setup_step'):
        module.setup_step(mode="diagonal")
model.eval()

input_seg = torch.randn(1, 16, 100)
full_out, _ = model(input_seg)
print(full_out)

s4_state = model.default_state()
stream_res = []
for i in range(input_seg.shape[-1]):
    part_input = input_seg[:, :, i]
    print(part_input.shape)
    part_res, s4_state = model.step(part_input, s4_state)
    stream_res.append(part_res)
stream_res = torch.cat(stream_res, dim=2)
print(stream_res)
print(torch.allclose(full_out, stream_res))
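A note on collecting the per-step outputs, as a shape sketch that does not depend on S4 at all. It assumes each step returns a tensor of the same shape as the input slice, `(batch, dim)`, which is an assumption about `model.step` and not something confirmed in this thread; `toy_step` is a hypothetical stand-in. Under that assumption, `torch.cat(..., dim=2)` has no third axis to concatenate along, whereas `torch.stack(..., dim=-1)` creates the time axis:

```python
import torch

batch, dim, length = 1, 16, 100
x = torch.randn(batch, dim, length)

# Hypothetical stand-in for a per-step model: maps (batch, dim) -> (batch, dim).
def toy_step(u):
    return 2.0 * u

# One (batch, dim) tensor per time step.
steps = [toy_step(x[:, :, i]) for i in range(length)]

# Concatenating 2-D tensors along dim=2 fails because that axis does not exist.
try:
    torch.cat(steps, dim=2)
    cat_failed = False
except (IndexError, RuntimeError):
    cat_failed = True

# Stacking adds a new trailing axis, giving (batch, dim, length).
stacked = torch.stack(steps, dim=-1)
print(cat_failed, stacked.shape)
```

If the step output already carries a trailing time axis of size 1, `torch.cat` along that axis would be the right call instead, so printing `part_res.shape` once is a cheap way to decide.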
traidn changed the title from "Any example of online inferense of S4 block?" to "Any example of online inference of S4 block?" on Dec 27, 2024
Hi @traidn, I am working on something similar at the moment. Could you post your error? I tried running your code, and it runs despite this error at the beginning:
Is this the error you are referring to? If so, the documentation in sashimi.py says:
S4 recurrence mode. Using `diagonal` can speed up generation by 10-20%.
`linear` should be faster theoretically but is slow in practice since it
dispatches more operations (could benefit from fused operations).
Note that `diagonal` could potentially be unstable if the diagonalization is numerically unstable
(although we haven't encountered this case in practice), while `dense` should always be stable.
So setting it to `module.setup_step(mode="linear")` instead of `diagonal` resolves this issue, even though it is slower than `diagonal`. Please let me know if this works for you.
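For intuition on what the documentation means by a diagonalized recurrence versus a dense one, here is a minimal sketch on a toy linear state space model. It is not S4's implementation: the matrices `A`, `B`, `C`, the eigenvalues, and the helper names are all made up for illustration. It steps the same system once with the dense matrix and once in the eigenbasis, where the state update becomes elementwise, and checks that the outputs agree.

```python
import torch

torch.manual_seed(0)
n, length = 4, 50
dtype = torch.float64

# Toy system (hypothetical): build A from chosen eigenvalues so that it is
# diagonalizable and stable, A = V diag(lam) V^{-1}.
lam = torch.tensor([0.9, 0.5, -0.3, 0.1], dtype=dtype)
V = torch.randn(n, n, dtype=dtype)
A = V @ torch.diag(lam) @ torch.linalg.inv(V)
B = torch.randn(n, dtype=dtype)
C = torch.randn(n, dtype=dtype)
u = torch.randn(length, dtype=dtype)

def run_dense(u):
    # x_k = A x_{k-1} + B u_k,  y_k = C x_k  (matrix-vector product per step)
    x = torch.zeros(n, dtype=dtype)
    ys = []
    for u_k in u:
        x = A @ x + B * u_k
        ys.append(C @ x)
    return torch.stack(ys)

def run_diagonal(u):
    # Change of basis z = V^{-1} x turns the update into an elementwise product.
    B_d = torch.linalg.solve(V, B)  # V^{-1} B
    C_d = C @ V                     # C V
    z = torch.zeros(n, dtype=dtype)
    ys = []
    for u_k in u:
        z = lam * z + B_d * u_k
        ys.append(C_d @ z)
    return torch.stack(ys)

y_dense = run_dense(u)
y_diag = run_diagonal(u)
print(torch.allclose(y_dense, y_diag, atol=1e-8))
```

The two loops agree here because `V` is reasonably well conditioned. When the eigenvector matrix is badly conditioned, the change of basis amplifies rounding error, which is the kind of instability the quoted note warns about for `diagonal`.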