Taking a gradient of a parameter input into index_update #6091
In JAX, it appears that by default, when a tensor is assigned to traced parameters, it is copied into the original tensor and no longer traced. Is this the right interpretation? Code to reproduce:
Output:
My desired output:
Is there some way I can achieve the desired output (assuming that's not JAX's intent / this is not a bug)?
Answered by davisyoshida, Mar 17, 2021
Answer selected by sunilkpai
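The line y.at[2:8].set(x) currently has no effect: JAX arrays are immutable, so .at[...].set() returns a new, updated array rather than modifying y in place, and here that result is discarded. Presumably you want: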
z = y.at[2:8].set(x)
return z[:5].sum()
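A minimal sketch contrasting the two versions, assuming a zero base array y of length 10 and a length-6 x (both hypothetical, since the original snippet was not preserved). Because .at[...].set() returns an updated copy, jax.grad only sees a dependence on x when that copy is the array the loss actually reads:

```python
import jax
import jax.numpy as jnp


def loss_discarded(x):
    # Hypothetical base array; the original snippet was not preserved.
    y = jnp.zeros(10)
    # The updated copy returned by .at[].set() is thrown away, so y is
    # still all zeros and the loss does not depend on x at all.
    y.at[2:8].set(x)
    return y[:5].sum()


def loss_fixed(x):
    y = jnp.zeros(10)
    # Bind the returned copy so that x actually flows into the loss.
    z = y.at[2:8].set(x)
    return z[:5].sum()


x = jnp.arange(1.0, 7.0)  # six values for the six slots y[2:8]

print(jax.grad(loss_discarded)(x))  # [0. 0. 0. 0. 0. 0.]
print(jax.grad(loss_fixed)(x))      # [1. 1. 1. 0. 0. 0.]
```

In the fixed version only x[0], x[1] and x[2] land in z[2:5], the part of z that z[:5].sum() reads, so only those entries receive a gradient of 1; the rest are written to z[5:8] and never reach the loss.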
Shouldn't the desired sum be 4, not 10 though? Unless I'm misreading the intent.