Inference on TPU Pod (v4-64) #183

Closed
creatorrr opened this issue Jan 7, 2025 · 10 comments

creatorrr commented Jan 7, 2025

Describe the bug
I am trying to run inference on a TPU Pod v4-64 (8 workers, 32 devices). I tried running the following code from one of the examples:

I'm a complete noob at both TPUs and EasyDeL, so maybe I'm missing something really obvious. Please help!

import jax
import jax.numpy as jnp
from jax import lax
from transformers import AutoTokenizer

import easydel as ed

dp, fsdp, tp, sp = 1, 1, 8, 4
sharding_axis_dims = (dp, fsdp, tp, sp)

# We have 32 devices in total
num_devices = len(jax.devices())
print("Number of JAX devices:", num_devices)

max_length = 6144

pretrained_model_name_or_path = "NaniDAO/Meta-Llama-3.1-8B-Instruct-ablated-v1"
dtype = jnp.bfloat16

# Create partition_axis telling EasyDel how to slice each dimension
partition_axis = ed.PartitionAxis(
    # batch_axis="dp",       # Use dp to shard the batch dimension
    # head_axis="tp",        # Use tp to shard the heads dimension
    # query_sequence_axis=None,  # or "sp" if you want sequence parallel on queries
    # key_sequence_axis=None,    # or "sp" if you want sequence parallel on keys
)

# Build model with the desired parallelism
model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    auto_shard_model=True,
    sharding_axis_dims=sharding_axis_dims,
    config_kwargs=ed.EasyDeLBaseConfigDict(
        use_scan_mlp=False,             # or True if you want scanning MLP
        partition_axis=partition_axis,
        attn_dtype=jnp.bfloat16,
        freq_max_position_embeddings=max_length,
        mask_max_position_embeddings=max_length,
        attn_mechanism=ed.AttentionMechanisms.FLASH_ATTN2,
    ),
    quantization_method="8bit",
    platform=None,
    partition_axis=partition_axis,
    param_dtype=dtype,
    dtype=dtype,
    precision=lax.Precision("fastest"),
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = tokenizer.eos_token_id

# Build inference object
inference = ed.vInference(
    model=model,
    processor_class=tokenizer,
    generation_config=ed.vInferenceConfig(
        max_new_tokens=1024,
        temperature=model.generation_config.temperature,
        top_p=model.generation_config.top_p,
        top_k=model.generation_config.top_k,
        eos_token_id=model.generation_config.eos_token_id,
        streaming_chunks=64,
    ),
)

# Precompile with batch_size=1
inference.precompile(1)
print("Inference name:", inference.inference_name)

Getting this error:

expected (batch_size=1, num_heads=4, q_seq_len=5120, kv_seq_len=6144), got (1, 1, 5120, 6144)

Traceback:

    raise ValueError(
ValueError: Attention bias shape mismatch: expected (batch_size=1, num_heads=4, q_seq_len=5120, kv_seq_len=6144), got (1, 1, 5120, 6144)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/diwank/exp/llama_example.py", line 83, in main
    inference.precompile(1)
  File "/home/diwank/.local/lib/python3.10/site-packages/easydel/inference/vinference/vinference.py", line 540, in precompile
    self._compile_and_lower_funs(
  File "/home/diwank/.local/lib/python3.10/site-packages/easydel/inference/vinference/vinference.py", line 472, in _compile_and_lower_funs
    causal_lm_first_iter_fn.lower(
  File "/home/diwank/.local/lib/python3.10/site-packages/easydel/inference/vinference/_fn.py", line 89, in causal_lm_first_iter_fn
    state = create_sampling_step(
  File "/home/diwank/.local/lib/python3.10/site-packages/fjformer/core/implicit_array.py", line 499, in implicit_f
    outs_flat = f_wrapped.call_wrapped(*flat_args)
  File "/home/diwank/.local/lib/python3.10/site-packages/easydel/inference/utils.py", line 297, in sampling_step
    model_outputs = model(
  File "/home/diwank/.local/lib/python3.10/site-packages/easydel/modules/llama/modeling_llama_flax.py", line 538, in __call__
    outputs = self.model(
  File "/home/diwank/.local/lib/python3.10/site-packages/easydel/modules/llama/modeling_llama_flax.py", line 441, in __call__
    layer_outputs = block(
  File "/home/diwank/.local/lib/python3.10/site-packages/easydel/modules/llama/modeling_llama_flax.py", line 320, in __call__
    attn_outputs = self.self_attn(
  File "/home/diwank/.local/lib/python3.10/site-packages/easydel/modules/llama/modeling_llama_flax.py", line 227, in __call__
    attentions = self.attention_performer(
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/diwank/.local/lib/python3.10/site-packages/easydel/layers/attention.py", line 619, in __call__
    return self.flash_attn2(
  File "/home/diwank/.local/lib/python3.10/site-packages/easydel/layers/attention.py", line 844, in flash_attn2
    attention_outputs = shard_map(
  File "/home/diwank/.local/lib/python3.10/site-packages/easydel/kernels/flash_attention_2.py", line 213, in __call__
    return self._compute_pallas(query, key, value, bias)
  File "/home/diwank/.local/lib/python3.10/site-packages/easydel/kernels/flash_attention_2.py", line 312, in _compute_pallas
    return pallas_flash_attention_tpu(
  File "/home/diwank/.local/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py", line 180, in flash_attention
    raise ValueError(
ValueError: Attention bias shape mismatch: expected (batch_size=1, num_heads=4, q_seq_len=5120, kv_seq_len=6144), got (1, 1, 5120, 6144)
creatorrr changed the title from "Inference on TPU Pod (v4-32)" to "Inference on TPU Pod (v4-64)" on Jan 7, 2025
erfanzar (Owner) commented Jan 8, 2025

Hi, and thanks for using EasyDeL!
There are some issues with vInference on multi-host VMs; they will be resolved soon.

creatorrr (Author) commented

Appreciate the quick response. :)

I was actually able to run the server, but when I tried using the chat completions endpoint, it threw this error:
One of pjit outputs was given the sharding of NamedSharding(mesh=Mesh('dp': 1, 'fsdp': 32, 'tp': 1, 'sp': 1), spec=PartitionSpec(('dp', 'fsdp'), 'sp'), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 32, but it is equal to 1 (full shape: (1, 130048))

erfanzar (Owner) commented Jan 8, 2025

You have changed the fsdp size to 32, so you need an input batch of at least 32.

You have two options. The first is to use sequence sharding, (1, 1, 1, -1), or to combine it with tensor parallelism, (1, 1, 4, -1); a sketch follows below. The second is to pass a batch of 32 inputs to vInference.
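Here is a minimal sketch of the first option, reusing the names from the snippet in the issue description (pretrained_model_name_or_path, partition_axis, dtype) and assuming the usual mesh convention where -1 means "use all remaining devices" for that axis:

sharding_axis_dims = (1, 1, 1, -1)    # dp=1, fsdp=1, tp=1, sp takes the remaining devices
# sharding_axis_dims = (1, 1, 4, -1)  # or tensor parallel (tp=4) plus sequence sharding

model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    auto_shard_model=True,
    sharding_axis_dims=sharding_axis_dims,
    partition_axis=partition_axis,
    param_dtype=dtype,
    dtype=dtype,
)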

creatorrr (Author) commented

It didn't work with any configuration combination in the end. :(

Any idea what the issue is with multi-node TPUs and how to fix it?

erfanzar (Owner) commented

Have you tested with sample inputs?

To get started:

  1. Install EOpod and configure it.
  2. Run the following commands:
     eopod kill-tpu --force
     eopod run "cd EasyDeL && git pull && USE_AOT=false python tests/vinference_runtime_test.py"
  3. Check for any errors.

I’ve tested this code on TPUv4-32. While I primarily use GPUs and don’t have full sponsorship from TRC, I can confirm that this code works reliably on both 16×A100 GPUs and TPUv4-32.

Also, I might be slow to respond to GitHub issues. If it’s urgent, feel free to DM me on Discord.

Let me know if you encounter any issues


nathom commented Jan 16, 2025

@erfanzar I've been trying to run inference on TPU v4-32 as well, but am running into issues. When I run the command you gave, I get

F0116 22:21:35.946161  685224 parse_flags_from_env.cc:226] Unknown flags in XLA_FLAGS: --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_pipelined_collectives=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_gpu_multi_streamed_windowed_einsum=true --xla_gpu_threshold_for_windowed_einsum_mib=0 --xla_gpu_enable_command_buffer=  --xla_gpu_multi_streamed_windowed_einsum=true --xla_gpu_threshold_for_windowed_einsum_mib=0 --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_command_buffer=
Perhaps you meant to specify these on the TF_XLA_FLAGS envvar?

Any tips? Also, what is your Discord?

erfanzar (Owner) commented

Hi @nathom,

Thank you for using EasyDeL!

From the errors you're encountering, it seems like you might be using a JAX version lower than 0.4.28. To resolve this, please ensure that you update your JAX installation to the latest version.

If updating JAX doesn't fix the issue, you can disable EasyDeL's auto flags by setting the environment variable EASYDEL_AUTO to off or 0. For example:

EASYDEL_AUTO=0 ...script
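If you prefer to set it from Python, here is a minimal sketch (assuming EASYDEL_AUTO is read when easydel is imported, so it must be set before the import):

import os

os.environ["EASYDEL_AUTO"] = "0"  # turn off EasyDeL's automatic XLA flag setup

import easydel as ed  # import only after the variable has been set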

Let me know if you need further assistance!

erfanzar (Owner) commented

@nathom Sorry, I forgot about Discord. Search for citifer on Discord, or join JaxLLM: https://discord.gg/KrXruTEy


nathom commented Jan 19, 2025

That fixes the issue, thanks!

erfanzar (Owner) commented Jan 20, 2025

Hello @creatorrr and @nathom,

Thank you for using EasyDeL!

I'm happy to share that the issue with sharding model state across multiple TPU nodes has been resolved, and multi-host inference is now fully functional.

Tested on:

  • V4-8
  • V4-16
  • V4-32
  • V4-64
  • V4-256

Let us know if you encounter any further issues or have feedback!

Feel free to re-open the issue if the problem persists or if you encounter any further challenges!
