Up to 2x speedup on GPUs using memory efficient attention (huggingface#532)

* 2x speedup using memory efficient attention

* remove einops dependency

* Swap K, M in op instantiation

* Simplify code, remove unnecessary maybe_init call and function, remove unused self.scale parameter

* make xformers a soft dependency

* remove one-liner functions

* change one-letter variables to appropriate names

* Remove Env variable dependency, remove MemoryEfficientCrossAttention class and use enable_xformers_memory_efficient_attention method

* Add memory efficient attention toggle to img2img and inpaint pipelines

* Clearer management of xformers' availability

* update optimizations markdown to add info about memory efficient attention

* add benchmarks for TITAN RTX

* More detailed explanation of how the mem eff benchmarks were run

* Removing autocast from optimization markdown

* import_utils: import torch only if it is available

Co-authored-by: Nouamane Tazi <[email protected]>
MatthieuToulemont and NouamaneTazi authored Nov 2, 2022
1 parent 7e05c53 commit ab9cf55
Showing 7 changed files with 144 additions and 4 deletions.
55 changes: 51 additions & 4 deletions models/attention.py
@@ -18,6 +18,15 @@
import torch.nn.functional as F
from torch import nn

from diffusers.utils.import_utils import is_xformers_available


if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None


class AttentionBlock(nn.Module):
"""
@@ -150,6 +159,10 @@ def _set_attention_slice(self, slice_size):
for block in self.transformer_blocks:
block._set_attention_slice(slice_size)

def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for block in self.transformer_blocks:
block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

def forward(self, hidden_states, context=None):
# note: if no context is given, cross-attention defaults to self-attention
batch, channel, height, weight = hidden_states.shape
@@ -206,6 +219,32 @@ def _set_attention_slice(self, slice_size):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size

def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if not is_xformers_available():
print("Here is how to install it")
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
" available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

def forward(self, hidden_states, context=None):
hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
@@ -239,6 +278,7 @@ def __init__(
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self._slice_size = None
self._use_memory_efficient_attention_xformers = False

self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
@@ -279,11 +319,13 @@ def forward(self, hidden_states, context=None, mask=None):
# TODO(PVP) - mask is currently never used. Remember to re-implement when used

# attention, what we cannot get enough of

if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value)
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)

# linear proj
hidden_states = self.to_out[0](hidden_states)
@@ -341,6 +383,11 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states

def _memory_efficient_attention_xformers(self, query, key, value):
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states


class FeedForward(nn.Module):
r"""
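For reference, a minimal sketch (not part of the diff; it assumes xformers and a CUDA device are available, and shapes, dtype and tolerance are illustrative) of what the new _memory_efficient_attention_xformers path computes, using the same (batch * heads, seq_len, head_dim) layout that CrossAttention passes in:

import torch
import xformers.ops

# Illustrative sizes: batch 2 x 8 heads, 1024 latent tokens, head dim 40.
q = torch.randn(16, 1024, 40, device="cuda", dtype=torch.float16)
k = torch.randn(16, 1024, 40, device="cuda", dtype=torch.float16)
v = torch.randn(16, 1024, 40, device="cuda", dtype=torch.float16)

# Baseline path: materializes the full 1024 x 1024 attention matrix per batch-head.
scale = q.shape[-1] ** -0.5
baseline = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1) @ v

# xformers kernel: same result without materializing the attention matrix.
efficient = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
assert torch.allclose(baseline, efficient, atol=1e-2)  # loose tolerance for fp16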
12 changes: 12 additions & 0 deletions models/unet_2d_blocks.py
@@ -367,6 +367,10 @@ def set_attention_slice(self, slice_size):
for attn in self.attentions:
attn._set_attention_slice(slice_size)

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
@@ -542,6 +546,10 @@ def set_attention_slice(self, slice_size):
for attn in self.attentions:
attn._set_attention_slice(slice_size)

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()

@@ -1117,6 +1125,10 @@ def set_attention_slice(self, slice_size):

self.gradient_checkpointing = False

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

def forward(
self,
hidden_states,
11 changes: 11 additions & 0 deletions models/unet_2d_condition.py
@@ -225,6 +225,17 @@ def set_attention_slice(self, slice_size):
if hasattr(block, "attentions") and block.attentions is not None:
block.set_attention_slice(slice_size)

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for block in self.down_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

for block in self.up_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
module.gradient_checkpointing = value
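Because the flag is propagated to the down, mid and up blocks above, the toggle can also be flipped directly on a bare UNet2DConditionModel. A hedged sketch (the model id and dtype are assumptions, not part of the commit):

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16
).to("cuda")

# Every CrossAttention module in the down, mid and up blocks now routes through xformers.
unet.set_use_memory_efficient_attention_xformers(True)

# Restore the default (or sliced) attention path.
unet.set_use_memory_efficient_attention_xformers(False)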
18 changes: 18 additions & 0 deletions pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -113,6 +113,24 @@ def __init__(
feature_extractor=feature_extractor,
)

def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)

def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
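A usage sketch for the new pipeline-level toggle (not part of the diff; it assumes xformers is installed, a CUDA GPU is present, and the model id below is accessible):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# Memory efficient attention takes precedence over attention slicing if both are enabled.
pipe.enable_xformers_memory_efficient_attention()

image = pipe("an astronaut riding a horse on the moon").images[0]
image.save("astronaut.png")

# Revert to the standard attention implementation.
pipe.disable_xformers_memory_efficient_attention()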
18 changes: 18 additions & 0 deletions pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -151,6 +151,24 @@ def disable_attention_slicing(self):
# set slice_size = `None` to disable `set_attention_slice`
self.enable_attention_slicing(None)

def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)

def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)

@torch.no_grad()
def __call__(
self,
18 changes: 18 additions & 0 deletions pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -151,6 +151,24 @@ def disable_attention_slicing(self):
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)

def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)

def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)

@torch.no_grad()
def __call__(
self,
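The same pair of methods is copied verbatim into the img2img and inpaint pipelines, so enabling the optimization is identical there. A brief sketch for img2img (the model id is an assumption; the pipeline call itself is unchanged):

import torch
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()
# ...call the pipeline as usual; denoising now runs through the xformers kernel.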
16 changes: 16 additions & 0 deletions utils/import_utils.py
@@ -168,6 +168,18 @@
except importlib_metadata.PackageNotFoundError:
_accelerate_available = False

_xformers_available = importlib.util.find_spec("xformers") is not None
try:
_xformers_version = importlib_metadata.version("xformers")
if _torch_available:
import torch

if torch.__version__ < version.Version("1.12"):
raise ValueError("PyTorch should be >= 1.12")
logger.debug(f"Successfully imported xformers version {_xformers_version}")
except importlib_metadata.PackageNotFoundError:
_xformers_available = False


def is_torch_available():
return _torch_available
@@ -205,6 +217,10 @@ def is_scipy_available():
return _scipy_available


def is_xformers_available():
return _xformers_available


def is_accelerate_available():
return _accelerate_available

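A small sketch of how the new soft-dependency guard is meant to be used downstream, mirroring the import pattern added to models/attention.py (nothing below is part of the diff):

from diffusers.utils.import_utils import is_xformers_available

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None

def attention_backend() -> str:
    # Pick the kernel at call time instead of failing at import time.
    return "xformers" if xformers is not None else "vanilla"

print(attention_backend())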