Consider helpers for accessing tensors from tuples and input wrappers #878

Closed
swfsql opened this issue Oct 27, 2023 · 1 comment

@swfsql
Contributor

swfsql commented Oct 27, 2023

Disclaimer: I'm a beginner in AI and am currently trying to implement a U-Net as a study.
Also, please consider this a draft idea.

The overall goal is to let the `Sequential` derive be used in more cases, by allowing a Module to first access a single tensor from a set (or tuple) of tensors.

For example:

```rust
#[derive(Default, Clone, Sequential)]
pub struct ResidualAdd2<T: Clone + std::fmt::Debug> {
    // input = Input
    pub split: SplitInto<(Id, Id)>,
    // input = (Input, Input)
    pub t: On<tuple::_0, T>, // access the first slot from the tuple and then pass it to the `T` module
    // input = (T::Output, Input)
    pub add: Add,
    // input = T::Output = Input
}
```

Note: `Id` just forwards its input as-is, and `Add` just calls `TryAdd` on both tensors.

I'm not sure this would be a good way to go, but the `On` module would apply the `T` module to the first tensor of the input flow, that is, to the first element of the tuple of two tensors.
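To make the `On` idea concrete without depending on dfdx, here is a minimal, dependency-free sketch. It assumes a simplified, hypothetical `Module` trait (dfdx's real trait has a different, fallible API), uses `Vec<f32>` as a stand-in for tensors, re-implements `Id`, `SplitInto` and `Add` with the behavior described in the note above, and writes out by hand the field-by-field `forward` chain that a `Sequential` derive would roughly produce. `On`, the `tuple::_0` marker and `Scale` are illustrative names only, not existing dfdx items.

```rust
use std::marker::PhantomData;

// Hypothetical, simplified module trait for illustration only
// (dfdx's real `Module` trait has a different, fallible API).
trait Module<Input> {
    type Output;
    fn forward(&self, x: Input) -> Self::Output;
}

/// Forwards its input as-is.
struct Id;
impl<X> Module<X> for Id {
    type Output = X;
    fn forward(&self, x: X) -> X {
        x
    }
}

/// Sends a copy of the input to each of the two inner modules.
struct SplitInto<M>(M);
impl<X: Clone, A: Module<X>, B: Module<X>> Module<X> for SplitInto<(A, B)> {
    type Output = (A::Output, B::Output);
    fn forward(&self, x: X) -> Self::Output {
        let (a, b) = &self.0;
        (a.forward(x.clone()), b.forward(x))
    }
}

/// Element-wise sum of a pair of "tensors" (`Vec<f32>` stands in for a tensor).
struct Add;
impl Module<(Vec<f32>, Vec<f32>)> for Add {
    type Output = Vec<f32>;
    fn forward(&self, (a, b): (Vec<f32>, Vec<f32>)) -> Vec<f32> {
        assert_eq!(a.len(), b.len(), "shape mismatch");
        a.iter().zip(&b).map(|(x, y)| x + y).collect()
    }
}

/// Hypothetical tuple-slot marker (only the first slot is shown).
#[allow(non_camel_case_types)]
mod tuple {
    pub struct _0;
}

/// Applies `M` to the tuple slot selected by `Idx`; the other slot passes through.
struct On<Idx, M> {
    idx: PhantomData<Idx>,
    module: M,
}
impl<A, B, M: Module<A>> Module<(A, B)> for On<tuple::_0, M> {
    type Output = (M::Output, B);
    fn forward(&self, (a, b): (A, B)) -> Self::Output {
        (self.module.forward(a), b)
    }
}

/// Example inner module: scales every element by a constant.
struct Scale(f32);
impl Module<Vec<f32>> for Scale {
    type Output = Vec<f32>;
    fn forward(&self, x: Vec<f32>) -> Vec<f32> {
        x.into_iter().map(|v| v * self.0).collect()
    }
}

/// The proposed `ResidualAdd2`, with the forward chain written out by hand.
struct ResidualAdd2<T> {
    split: SplitInto<(Id, Id)>,
    t: On<tuple::_0, T>,
    add: Add,
}
impl<T: Module<Vec<f32>, Output = Vec<f32>>> Module<Vec<f32>> for ResidualAdd2<T> {
    type Output = Vec<f32>;
    fn forward(&self, x: Vec<f32>) -> Vec<f32> {
        let x = self.split.forward(x); // (x, x)
        let x = self.t.forward(x); // (T(x), x)
        self.add.forward(x) // T(x) + x
    }
}

fn main() {
    let block = ResidualAdd2 {
        split: SplitInto((Id, Id)),
        t: On { idx: PhantomData, module: Scale(2.0) },
        add: Add,
    };
    println!("{:?}", block.forward(vec![1.0_f32, 2.0, 3.0])); // [3.0, 6.0, 9.0]
}
```

Because the marker is a type parameter, selecting a different slot means adding another `impl` for a different marker type, and a mismatch between the selected slot and `T`'s input type is caught at compile time.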

I'm not sure this holds, but by not nesting more layer type information directly into the `split` field, we may be able to use some Modules that would otherwise need to be recursive. That said, I wouldn't be too happy dealing with tuple indexes all over the architecture.

If this direction has merit, it might also be better for the tensors to be named and stored in structs, perhaps with some derived accessors, such as:

```rust
#[input_wrapper] // generates a `mod split { .. }`
pub struct Split<Forward, Skip> {
    pub forward: Forward,
    pub skip: Skip,
}

#[derive(Default, Clone, Sequential)]
pub struct ResidualAdd2<T: Clone + std::fmt::Debug> {
    // input = Input
    pub split: SplitInto<(Id, Id)>,
    // input = (Input, Input)
    pub input_to_wrapper: split::FromTuple, // converts from (A, B) into Split<A, B>
    // input = Split<Input, Input>
    pub t: On<split::forward, T>, // access the field `forward` and then pass it to the `T` module
    // input = Split<T::Output, Input>
    pub input_to_tuple: split::IntoTuple, // converts from Split<A, B> into (A, B)
    // input = (T::Output, Input)
    pub add: Add,
    // input = T::Output = Input
}
```

In this case the effect would be the same, but raw tuple indexes would no longer be used. The "access" concept is inspired by how druid works (as far as generating a module with a struct representing each field goes), although it's not clear whether going that way pays off.
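As a rough illustration, here is a hand-written guess at what `#[input_wrapper]` could expand to for `Split`. It assumes the same kind of simplified, hypothetical `Module` trait (not dfdx's actual API) and uses `Vec<f32>` as a stand-in for tensors. The field markers `split::forward` and `split::skip` play the role a tuple index such as `tuple::_0` would otherwise play; `On`, `FromTuple`, `IntoTuple` and `Scale` are illustrative names, not existing dfdx items.

```rust
use std::marker::PhantomData;

// Hypothetical, simplified module trait for illustration only
// (dfdx's real `Module` trait has a different, fallible API).
trait Module<Input> {
    type Output;
    fn forward(&self, x: Input) -> Self::Output;
}

/// Applies `M` to the part of the input selected by the marker type `Sel`.
struct On<Sel, M> {
    sel: PhantomData<Sel>,
    module: M,
}

/// The user-written input wrapper from the proposal.
pub struct Split<Forward, Skip> {
    pub forward: Forward,
    pub skip: Skip,
}

// A hand-written guess at what `#[input_wrapper]` could generate.
#[allow(non_camel_case_types)]
mod split {
    use super::{Module, On, Split};

    /// Converts `(A, B)` into `Split<A, B>`.
    pub struct FromTuple;
    impl<A, B> Module<(A, B)> for FromTuple {
        type Output = Split<A, B>;
        fn forward(&self, (a, b): (A, B)) -> Self::Output {
            Split { forward: a, skip: b }
        }
    }

    /// Converts `Split<A, B>` back into `(A, B)`.
    pub struct IntoTuple;
    impl<A, B> Module<Split<A, B>> for IntoTuple {
        type Output = (A, B);
        fn forward(&self, s: Split<A, B>) -> Self::Output {
            (s.forward, s.skip)
        }
    }

    /// One marker per field, used as `On<split::forward, T>` or `On<split::skip, T>`.
    pub struct forward;
    pub struct skip;

    impl<A, B, M: Module<A>> Module<Split<A, B>> for On<forward, M> {
        type Output = Split<M::Output, B>;
        fn forward(&self, s: Split<A, B>) -> Self::Output {
            Split { forward: self.module.forward(s.forward), skip: s.skip }
        }
    }

    impl<A, B, M: Module<B>> Module<Split<A, B>> for On<skip, M> {
        type Output = Split<A, M::Output>;
        fn forward(&self, s: Split<A, B>) -> Self::Output {
            Split { forward: s.forward, skip: self.module.forward(s.skip) }
        }
    }
}

/// Example inner module: scales every element by a constant.
struct Scale(f32);
impl Module<Vec<f32>> for Scale {
    type Output = Vec<f32>;
    fn forward(&self, x: Vec<f32>) -> Vec<f32> {
        x.into_iter().map(|v| v * self.0).collect()
    }
}

fn main() {
    let x = vec![1.0_f32, 2.0, 3.0];
    // (x, x) -> Split { forward: x, skip: x }
    let s = split::FromTuple.forward((x.clone(), x));
    // Apply `Scale` to the `forward` field only, selected by name instead of by index.
    let t: On<split::forward, Scale> = On { sel: PhantomData, module: Scale(2.0) };
    let s = t.forward(s);
    // Back to a tuple, then add the two halves element-wise.
    let (a, b) = split::IntoTuple.forward(s);
    let y: Vec<f32> = a.iter().zip(&b).map(|(p, q)| p + q).collect();
    println!("{:?}", y); // [3.0, 6.0, 9.0]
}
```

One catch with generating lowercase unit structs named after the fields: inside `mod split`, a binding such as `let forward = ...` would be parsed as a pattern matching the unit struct `forward`, and field-init shorthand like `Split { forward, skip }` would refer to the marker values. Generated code therefore has to avoid field-named bindings and shorthand.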

Thanks for reading! Please also consider this a draft idea.
Edit: I will try to create a draft PR containing the attribute macro for this.

swfsql changed the title from "Consider helpers for accessing tensors from tuples" to "Consider helpers for accessing tensors from tuples and input wrappers" on Nov 4, 2023
swfsql mentioned this issue on Nov 6, 2023
@swfsql
Contributor Author

swfsql commented Mar 1, 2024

As stated in the PR, I don't think this is a commonly needed case, and it would make more sense as an experiment in an external library. Feel free to ping me if anyone has questions or suggestions.

swfsql closed this as completed on Mar 1, 2024