memory_efficient_attention_backward: Correctly reshape output grads
ghstack-source-id: 5df45ecbfabfa97198e4ea787ad897d46905c277
Pull Request resolved: #593
danthe3rd committed Dec 15, 2022
Parent: 9972795, commit: cd6af2d
Showing 1 changed file with 8 additions and 1 deletion.
xformers/ops/fmha/__init__.py (8 additions, 1 deletion)

```diff
@@ -339,6 +339,9 @@ def _memory_efficient_attention_backward(
             f"out.shape : {ctx.out.shape} \n"
             f"query.shape: {inp.query.shape}"
         )
+    shape_dq, shape_dk, shape_dv = tuple(
+        x.shape for x in (inp.query, inp.key, inp.value)
+    )
     inp.normalize_bmhk()
     # LSE has shape [B, H, M] while query has shape [B, M, H, K]
     if (
@@ -374,7 +377,11 @@ def _memory_efficient_attention_backward(
         raise ValueError(
             f"xformers.memory_efficient_attention: Operator {op.NAME} does not support this input"
         )
-    return op.apply(ctx, inp, grad)
+    grads = op.apply(ctx, inp, grad)
+    grads.dq = grads.dq.reshape(shape_dq)
+    grads.dk = grads.dk.reshape(shape_dk)
+    grads.dv = grads.dv.reshape(shape_dv)
+    return grads
 
 
 __all__ = [
```

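For context, here is a minimal sketch of the pattern this commit applies: record the caller-provided shapes before `inp.normalize_bmhk()` rewrites the inputs into 4D BMHK form, run the backward, then reshape each gradient back to match its input tensor, since autograd requires `grad.shape == input.shape`. This is a standalone illustration, not the actual xformers internals; `run_backward` and `fake_backward` are hypothetical stand-ins for `op.apply(ctx, inp, grad)`, and the 3D-to-4D expansion mirrors what `normalize_bmhk()` is assumed to do for [B, M, K] inputs.

```python
import torch

def backward_with_original_shapes(query, key, value, run_backward):
    # Record the caller-provided shapes *before* any normalization, e.g. 3D
    # [B, M, K] inputs that get expanded to a 4D [B, M, 1, K] layout internally.
    shape_dq, shape_dk, shape_dv = query.shape, key.shape, value.shape
    # The operator may compute gradients in the normalized 4D layout.
    dq, dk, dv = run_backward(query, key, value)
    # Undo the normalization so each gradient matches its input tensor.
    return dq.reshape(shape_dq), dk.reshape(shape_dk), dv.reshape(shape_dv)

# Hypothetical usage with 3D inputs, the case the fix targets:
q = k = v = torch.randn(2, 16, 64)  # [B, M, K]

def fake_backward(q, k, v):
    # Stand-in that returns gradients in the normalized 4D [B, M, 1, K] shape.
    return tuple(torch.zeros_like(x).unsqueeze(2) for x in (q, k, v))

dq, dk, dv = backward_with_original_shapes(q, k, v, fake_backward)
assert dq.shape == q.shape  # gradients line up with the inputs again
```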