Update-tp test (#35844)
* update test for now

* up

* cleanup

* update todo
ArthurZucker authored Feb 3, 2025
1 parent 62db3e6 commit 7eecdf2
Showing 2 changed files with 36 additions and 12 deletions.
2 changes: 2 additions & 0 deletions src/transformers/pytorch_utils.py
@@ -343,6 +343,8 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
return torch.isin(elements, test_elements)


# TODO need to add the __repr__ that shows that it is a colwise parallel
# See https://github.com/pytorch/pytorch/issues/145726
def translate_to_torch_parallel_style(style: str):
"""
In model configurations, we use a neutral type (string) to specify parallel
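Aside (not part of the commit): the TODO above asks for a __repr__ that reveals which parallel style a plan entry maps to. As a rough illustration, assuming the torch.distributed.tensor.parallel style classes that translate_to_torch_parallel_style returns, one way to get a readable repr would be:

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

# Hypothetical sketch: make each ParallelStyle print as its class name,
# e.g. "ColwiseParallel", so a parallelized plan is easier to inspect.
for _style in (ColwiseParallel, RowwiseParallel):
    _style.__repr__ = lambda self: type(self).__name__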
46 changes: 34 additions & 12 deletions tests/tp/test_tp.py
@@ -17,6 +17,7 @@
import tempfile
import textwrap

# TORCH_LOGS=+dtensor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py
from transformers import is_torch_available
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
@@ -110,9 +111,8 @@ def test_loading_memory_consumption(self):

# Test settings
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
bs = 4
seqlen = 64

bs = 1
seqlen = 4096
# Get distributed settings
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
@@ -124,23 +124,45 @@

# Get model config
config = LlamaConfig.from_pretrained(model_id)
# Shrink model size
config.num_hidden_layers //= 8
config.vocab_size //= 8

config.hidden_size = 2048
config.attention_bias = False
# Instantiate model
with device:
model = LlamaModel(config)
model = LlamaModel(config).to(dtype=torch.float16)

model.eval()

# Tensor Parallel
if world_size > 1:
model.tensor_parallel(device_mesh)

# Run model

inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
with torch.no_grad():
out = model(inputs)

# Test cuda graphing explicitly
with torch.cuda.device(device):
print("Cuda graphing")
with torch.no_grad():
inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
# CUDA Graph setup
s = torch.cuda.Stream(device=device)
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for i in range(3):
out = model(inputs)
torch.cuda.current_stream().wait_stream(s)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
out = model(inputs)

for _ in range(2):
g.replay()
s.synchronize()

assert out.last_hidden_state.shape == torch.Size([bs, seqlen, config.hidden_size])

# Test compile
with torch.no_grad():
out = model(inputs)
model.forward = torch.compile(model.forward, mode="reduce-overhead")
out = model(inputs)
out = model(inputs)
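For reference, the stream-warmup / capture / replay sequence exercised in the test follows the standard torch.cuda.CUDAGraph recipe. A minimal standalone sketch (the toy module and shapes are hypothetical, a CUDA device is assumed):

import torch

# Any CUDA model with static input shapes can be captured.
model = torch.nn.Linear(16, 16).cuda().eval()
static_input = torch.randn(8, 16, device="cuda")

# Warm up on a side stream so capture does not record one-off allocator work.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        with torch.no_grad():
            static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward pass into a graph, then replay it with new input values.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    with torch.no_grad():
        static_output = model(static_input)

static_input.copy_(torch.randn(8, 16, device="cuda"))  # refill the captured input buffer in place
g.replay()  # re-runs the recorded kernels; results land in static_output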
