Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split TryConcatAlong into different traits #892

Merged
merged 4 commits into from
Dec 4, 2023

Conversation

swfsql
Copy link
Contributor

@swfsql swfsql commented Nov 17, 2023

Closes #891.

  • Added a failing test.
    • Unsure where to add the test. It's about a Module, so it should not be in dfdx_core. But there's no Module in dfdx::nn representing tensor concatenation, so I've added as an integration test.
    • The test was later adjusted to the new trait, being able to run successfully.
  • Deprecated TryConcatAlong in favor of TryConcatTensorAlong or TryConcatShapeAlong.
  • Created concat_tensor_along/ and concat_shape_along/.
    • Copied relevant sections and files from concat_along, adjusting where necessary.
    • Moved concat_along/ kernels to concat_tensor_along/.

The added test, that initially failed to compile:

#[test]
fn test_issue_891() {
#[derive(Default, Debug, Clone, Copy, CustomModule)]
pub struct Id;
impl<Input> Module<Input> for Id {
type Output = Input;
fn try_forward(&self, x: Input) -> Result<Self::Output, Error> {
Ok(x)
}
}
#[derive(Default, Debug, Clone, Copy, dfdx_derives::CustomModule)]
struct ConcatTensorAlong<Ax: Axes<Array = [isize; 1]> + Debug>(pub Ax);
impl<Input, const AXIS: isize> Module<Input> for ConcatTensorAlong<Axis<AXIS>>
where
Input: TryConcatAlong<Axis<AXIS>>,
{
type Output = <Input as TryConcatAlong<Axis<AXIS>>>::Output;
fn try_forward(&self, x: Input) -> Result<Self::Output, Error> {
x.try_concat_along(Axis)
}
}
type Arch = (SplitInto<(Id, Id)>, ConcatTensorAlong<Axis<0>>);
let dev = Cpu::default();
let x = dev.tensor([1.]);
let m = dev.build_module::<f32>(Arch::default());
let y = m.forward(x);
/*
error[E0275]: overflow evaluating the requirement `((_, _, _, _), (..., ..., ..., ...)): dfdx::prelude::TryConcatAlong<...>`
--> dfdx/tests/issue_tests.rs:36:15
|
36 | let y = m.forward(x);
| ^^^^^^^
|
*/
}

@swfsql swfsql marked this pull request as ready for review November 17, 2023 02:06
Copy link
Owner

@coreylowman coreylowman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some tiny changes, but otherwise looks good and I was able to verify that the current main doesn't compile the test and the fix does 👍

@coreylowman
Copy link
Owner

@swfsql looks good - do you mind merging with the webgpu changes I just merged? Should be pretty straightforward 🤞 No need to test with webgpu features yet

swfsql and others added 4 commits December 3, 2023 19:04
- Deprecated `TryConcatAlong` in favor of `TryConcatTensorAlong` or `TryConcatShapeAlong`.
- Created `concat_tensor_along/` and `concat_shape_along/`.
  - Copied relevant sections and files from `concat_along`, adjusting where necessary.
  - Moved `concat_along/` kernels to `concat_tensor_along/`.
- Adjusted the issue's integration test to the new trait, which runs successfully.
@swfsql
Copy link
Contributor Author

swfsql commented Dec 4, 2023

@coreylowman sure np, I've rebased and basically just moved the webgpu kernel to where the others are.

I've made one change, remaking this item as public so it's the same behavior as from before.

pub use super::concat_tensor_along::ConcatAlongKernel;

@coreylowman coreylowman merged commit beee7a1 into coreylowman:main Dec 4, 2023
4 checks passed
@coreylowman
Copy link
Owner

Woohoo, thanks for this change! 🎉

@swfsql swfsql deleted the issue-891 branch March 1, 2024 16:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Split TryConcatAlong into different traits
2 participants