diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 3ea03b0b40..a556677478 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -623,9 +623,9 @@ impl Tensor { } Op::Unary(arg, UnaryOp::Silu) => { let sum_grad = grads.or_insert(arg)?; - // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x))) + // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x))) = sigmoid(x) * (1 - node) + node let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?; - let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?; + let silu_grad = &sigmoid_arg * (1. - *node) + *node; *sum_grad = sum_grad.add(&(&grad * silu_grad)?)? } Op::Elu(arg, alpha) => {