Redundant computations when use_hankel_L=True #3

Open · windsornguyen opened this issue on Sep 19, 2024 · 0 comments
Labels: enhancement (New feature or request)

@windsornguyen (Member):
In stu.py, we compute the negative featurization (U_minus / spectral_minus) even when use_hankel_L=True, in which case it is never used.

class STU(nn.Module):
    def __init__(self, config, phi, n) -> None:
        super(STU, self).__init__()
        self.config = config
        self.phi = phi
        self.n = n
        self.K = config.num_eigh
        self.d_in = config.n_embd
        self.d_out = config.n_embd
        self.use_hankel_L = config.use_hankel_L
        self.use_approx = config.use_approx
        self.flash_fft = (
            FlashFFTConv(self.n, dtype=torch.bfloat16)
            if config.use_flash_fft and flash_fft_available
            else None
        )
        if self.use_approx:
            self.M_inputs = nn.Parameter(
                torch.empty(self.d_in, self.d_out, dtype=config.torch_dtype)
            )
            self.M_filters = nn.Parameter(
                torch.empty(self.K, self.d_in, dtype=config.torch_dtype)
            )
        else:
            self.M_phi_plus = nn.Parameter(
                torch.empty(self.K, self.d_in, self.d_out, dtype=config.torch_dtype)
            )
            if not self.use_hankel_L:
                self.M_phi_minus = nn.Parameter(
                    torch.empty(self.K, self.d_in, self.d_out, dtype=config.torch_dtype)
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_approx:
            # Contract inputs and filters over the K and d_in dimensions, then convolve
            x_proj = x @ self.M_inputs
            phi_proj = self.phi @ self.M_filters
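            # NOTE: both branches below compute spectral_minus even when
            # use_hankel_L=True, in which case it is thrown away at the return.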
            if self.flash_fft:
                spectral_plus, spectral_minus = flash_convolve(
                    x_proj, phi_proj, self.flash_fft, self.use_approx
                )
            else:
                spectral_plus, spectral_minus = convolve(
                    x_proj, phi_proj, self.n, self.use_approx
                )
        else:
            # Convolve inputs and filters,
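            # NOTE: U_minus is computed unconditionally here, even though it is
            # only used when use_hankel_L=False.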
            if self.flash_fft:
                U_plus, U_minus = flash_convolve(
                    x, self.phi, self.flash_fft, self.use_approx
                )
            else:
                U_plus, U_minus = convolve(x, self.phi, self.n, self.use_approx)
            # Then, contract over the K and d_in dimensions
            spectral_plus = torch.tensordot(
                U_plus, self.M_phi_plus, dims=([2, 3], [0, 1])
            )
            if not self.use_hankel_L:
                spectral_minus = torch.tensordot(
                    U_minus, self.M_phi_minus, dims=([2, 3], [0, 1])
                )

        return spectral_plus if self.use_hankel_L else spectral_plus + spectral_minus

This can be optimized by refactoring stu.py and the convolve and flash_convolve functions to remove these redundant computations, but it'll take some careful design choices to avoid breaking anything, as the logic can get quite hairy!
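One possible shape for the fix, as a rough sketch only: thread a use_hankel_L flag through convolve so the alternating-sign branch is never built. The fft_causal_conv helper and the use_hankel_L keyword below are hypothetical, and the shapes are simplified relative to the real convolve (which returns (bsz, seq_len, K, d_in) tensors):

import torch

def fft_causal_conv(u, v, n):
    # Generic FFT-based causal convolution along the sequence dimension.
    U = torch.fft.rfft(u.float(), n=n, dim=1)
    V = torch.fft.rfft(v.float(), n=n, dim=1)
    return torch.fft.irfft(U * V, n=n, dim=1)[:, : u.shape[1]].type_as(u)

def convolve(x, phi, n, use_approx, use_hankel_L=False):
    # Hypothetical signature: accept use_hankel_L so the negative
    # featurization is skipped when it would be discarded anyway.
    # Shapes are illustrative: x is (bsz, seq_len, d), phi is (seq_len, d);
    # use_approx is ignored in this sketch.
    seq_len = x.shape[1]
    U_plus = fft_causal_conv(x, phi.unsqueeze(0), n)
    if use_hankel_L:
        return U_plus, None  # never materialize the minus branch
    sgn = torch.ones(1, seq_len, 1, device=x.device, dtype=x.dtype)
    sgn[:, 1::2] = -1
    U_minus = fft_causal_conv(x * sgn, phi.unsqueeze(0), n) * sgn
    return U_plus, U_minus

The call sites in STU.forward would then pass self.use_hankel_L through, e.g. convolve(x, self.phi, self.n, self.use_approx, self.use_hankel_L), with the existing `if not self.use_hankel_L` guards left in place.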

@windsornguyen added the enhancement label on Sep 19, 2024