From 82c314b553c57c4a5927627e341138d9f18c23d5 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Mon, 4 Dec 2023 17:57:24 -0500 Subject: [PATCH] updates to current main --- dfdx-derives/src/lib.rs | 20 ++++++++++---------- dfdx/src/nn/layers/on.rs | 7 +++---- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/dfdx-derives/src/lib.rs b/dfdx-derives/src/lib.rs index 442fb690..c0782fdb 100644 --- a/dfdx-derives/src/lib.rs +++ b/dfdx-derives/src/lib.rs @@ -1167,10 +1167,10 @@ pub fn input_wrapper( ); quote! { #[doc = #doc1] - #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, crate::prelude::CustomModule)] + #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, ::dfdx::CustomModule)] pub struct FromTuple; #[doc = #doc2] - #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, crate::prelude::CustomModule)] + #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, ::dfdx::prelude::CustomModule)] pub struct IntoTuple; } }; @@ -1235,16 +1235,16 @@ pub fn input_wrapper( let doc2 = format!("Module to convert a [`{}`] into a tuple.", wrapper_ident,); quote! { #[doc = #doc1] - impl<#(#wrapper_generic_names), *> crate::prelude::Module<(#(#field_ty_names), *)> for FromTuple { + impl<#(#wrapper_generic_names), *> ::dfdx::prelude::Module<(#(#field_ty_names), *)> for FromTuple { type Output = #wrapper_ident<#(#wrapper_generic_names), *>; - fn try_forward(&self, x: (#(#field_ty_names), *)) -> Result { + fn try_forward(&self, x: (#(#field_ty_names), *)) -> Result { Ok(x.into()) } } #[doc = #doc2] - impl<#(#wrapper_generic_names), *> crate::prelude::Module<#wrapper_ident<#(#wrapper_generic_names), *>> for IntoTuple { + impl<#(#wrapper_generic_names), *> ::dfdx::prelude::Module<#wrapper_ident<#(#wrapper_generic_names), *>> for IntoTuple { type Output = (#(#field_ty_names), *); - fn try_forward(&self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { + fn try_forward(&self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { Ok(x.into()) } } @@ -1295,7 +1295,7 @@ pub fn input_wrapper( ); } contains_ident = true; - quote!(>::Output) + quote!(>::Output) } else { quote!(#ty_ident) } @@ -1334,15 +1334,15 @@ pub fn input_wrapper( let field_access_module = quote! { #[doc = #doc] - impl, #(#wrapper_generic_names), *> crate::prelude::Module<#wrapper_ident<#(#wrapper_generic_names), *>> for crate::prelude::On<#on_acccess, M> { + impl, #(#wrapper_generic_names), *> ::dfdx::prelude::Module<#wrapper_ident<#(#wrapper_generic_names), *>> for ::dfdx::prelude::On<#on_acccess, M> { type Output = #wrapper_ident<#(#output_generics), *>; - fn try_forward(&self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { + fn try_forward(&self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { #(#field_extraction)* let #forward = self.t.try_forward(#forward)?; let x = #field_replacement; Ok(x) } - fn try_forward_mut(&mut self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { + fn try_forward_mut(&mut self, x: #wrapper_ident<#(#wrapper_generic_names), *>) -> Result { #(#field_extraction)* let #forward = self.t.try_forward_mut(#forward)?; let x = #field_replacement; diff --git a/dfdx/src/nn/layers/on.rs b/dfdx/src/nn/layers/on.rs index 9d0465cb..6396d70c 100644 --- a/dfdx/src/nn/layers/on.rs +++ b/dfdx/src/nn/layers/on.rs @@ -4,13 +4,12 @@ use std::marker::PhantomData; // TODO: try making a Call module, whih allows calling an arbitrary method on the input. /// Applies module `T` into an input field from a wrapper. -#[derive( - Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors, -)] +#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))] +#[derive(Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams)] #[repr(transparent)] pub struct On { #[module] - #[serialize] + #[cfg_attr(feature = "safetensors", serialize)] pub t: T, pub _n: PhantomData,