Commit: Fix inference

wilson1yan committed Aug 7, 2024
1 parent 3778ac1 commit b8e3602
Showing 3 changed files with 16 additions and 17 deletions.
2 changes: 1 addition & 1 deletion gpu_requirements.txt
@@ -5,7 +5,7 @@ optax==0.2.2
 --extra-index-url https://download.pytorch.org/whl/cpu
 torch==2.0.0
 transformers==4.40.0
-ringattention==0.1.1
+ringattention @ git+https://github.com/forhaoliu/ringattention.git
 datasets
 einops
 tqdm
29 changes: 14 additions & 15 deletions lwm/llama.py
@@ -27,7 +27,7 @@
 
 from ml_collections import ConfigDict
 from tux import function_args_to_config, load_pickle, open_file, with_sharding_constraint, get_jax_mesh
-from ringattention import ringattention, blockwise_feedforward, ringattention_jax
+from ringattention import ringattention, blockwise_feedforward, ringattention_jax, ringattention_inference
 
 
 LLAMA_STANDARD_CONFIGS = {
@@ -579,8 +579,11 @@ def __call__(
                 segment_mask = None
             else:
                 causal_mask = self.causal_mask[:, :, :query_length, :key_length]
-                segment_mask = segment_ids[:, :, None] == segment_ids[:, None, :]
-                segment_mask = segment_mask[:, None]
+                if segment_ids is not None:
+                    segment_mask = segment_ids[:, :, None] == segment_ids[:, None, :]
+                    segment_mask = segment_mask[:, None]
+                else:
+                    segment_mask = None
 
             batch_size = hidden_states.shape[0]
             causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
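
Note: the hunk above makes the segment mask optional, so plain autoregressive inference (no packed sequences, segment_ids is None) falls back to the causal mask alone. A minimal, self-contained sketch of that logic, using a hypothetical combine_masks helper rather than the repo's code:

    import jax.numpy as jnp

    def combine_masks(causal_mask, segment_ids=None):
        # causal_mask: bool [batch, 1, q_len, kv_len]; segment_ids: int [batch, seq_len] or None
        if segment_ids is None:
            return causal_mask                    # inference path: causal mask only
        segment_mask = segment_ids[:, :, None] == segment_ids[:, None, :]   # [batch, q, kv]
        segment_mask = segment_mask[:, None]                                # [batch, 1, q, kv]
        return jnp.logical_and(causal_mask, segment_mask)

    q_len = kv_len = 4
    causal = jnp.tril(jnp.ones((1, 1, q_len, kv_len), dtype=bool))
    segments = jnp.array([[0, 0, 1, 1]])          # two packed sequences in one row
    print(combine_masks(causal, segments)[0, 0].astype(int))
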
@@ -596,7 +599,7 @@ def __call__(
             q_sp_dim = None if xq.shape[1] == 1 else 'sp'
             attn_weights = None
             ring_attention_sharded = shard_map(
-                partial(ringattention_jax, axis_name="sp"), mesh=LLaMAConfig.get_jax_mesh(self.config.mesh_dim),
+                partial(ringattention_inference, axis_name="sp"), mesh=LLaMAConfig.get_jax_mesh(self.config.mesh_dim),
                 in_specs=(
                     PS(("dp", "fsdp"), q_sp_dim, "tp", None),
                     PS(("dp", "fsdp"), "sp", "tp", None),
@@ -667,9 +670,9 @@ class FlaxLLaMABlock(nn.Module):
     def setup(self) -> None:
         attention_module = FlaxLLaMAAttention
         mlp_module = FlaxLLaMAMLP
-        if self.scan_map:
+        if self.config.scan_mlp:
             mlp_module = remat(
-                self.feed_forward, static_argnums=(1,),
+                mlp_module, static_argnums=(1,),
                 policy=jax.checkpoint_policies.nothing_saveable,
                 prevent_cse=not self.config.scan_layers,
             )
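
Note: the fix above makes the branch key off config.scan_mlp and wraps the MLP class itself in remat. A minimal sketch of that wrapping pattern with a hypothetical TinyMLP, not the repo's FlaxLLaMAMLP (static_argnums and prevent_cse omitted here for brevity):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class TinyMLP(nn.Module):
        hidden: int
        out: int

        @nn.compact
        def __call__(self, x):
            return nn.Dense(self.out)(nn.gelu(nn.Dense(self.hidden)(x)))

    RematMLP = nn.remat(
        TinyMLP,                                          # wrap the module class, not an instance
        policy=jax.checkpoint_policies.nothing_saveable,  # save nothing; recompute in the backward pass
    )

    mlp = RematMLP(hidden=64, out=16)
    x = jnp.ones((2, 16))
    params = mlp.init(jax.random.PRNGKey(0), x)
    grads = jax.grad(lambda p: mlp.apply(p, x).sum())(params)
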
@@ -767,7 +770,7 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
         attention_mask = jnp.ones_like(input_ids)
-        segment_ids = jnp.zeros_like(input_ids)
+        segment_ids = None
         position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
@@ -786,7 +789,7 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz
                 return_dict=False,
             )
         else:
-            module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
+            module_init_outputs = self.module.init(rngs, input_ids, attention_mask, segment_ids, position_ids, return_dict=False)
 
         random_params = module_init_outputs["params"]
 
@@ -812,13 +815,13 @@ def init_cache(self, batch_size, max_length):
         # init input variables to retrieve cache
         input_ids = jnp.ones((batch_size, max_length))
         attention_mask = jnp.ones_like(input_ids)
-        segment_ids = jnp.zeros_like(input_ids)
+        segment_ids = None
         position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
 
         init_variables = self.module.init(
             jax.random.PRNGKey(0), input_ids, attention_mask, segment_ids, position_ids, return_dict=False, init_cache=True
         )
-        return init_variables["cache"].unfreeze()
+        return init_variables["cache"]
 
     @add_start_docstrings_to_model_forward("")
     def __call__(
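
Note: dropping the trailing .unfreeze() matches recent flax, where init() hands back variable collections as plain dicts. A minimal sketch with a hypothetical Counter module carrying a "cache" variable, analogous to the attention KV cache (behavior assumed for recent flax versions):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class Counter(nn.Module):
        @nn.compact
        def __call__(self, x):
            # a variable in the "cache" collection, like the KV cache in attention
            idx = self.variable("cache", "cache_index", lambda: jnp.zeros((), jnp.int32))
            idx.value = idx.value + 1
            return x

    variables = Counter().init(jax.random.PRNGKey(0), jnp.ones((1,)))
    cache = variables["cache"]            # usable directly; no .unfreeze() needed
    print(cache)
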
@@ -851,8 +854,6 @@ def __call__(
 
         if attention_mask is None:
             attention_mask = jnp.ones((batch_size, sequence_length))
-        if segment_ids is None:
-            segment_ids = jnp.zeros((batch_size, sequence_length))
 
         # Handle any PRNG if needed
         rngs = {}
@@ -871,7 +872,7 @@ def __call__(
             inputs,
             jnp.array(input_ids, dtype="i4"),
             jnp.array(attention_mask, dtype="i4"),
-            jnp.array(segment_ids, dtype="i4"),
+            segment_ids,
             jnp.array(position_ids, dtype="i4"),
             not train,
             False,
@@ -1076,8 +1077,6 @@ def __call__(
         batch_size, seq_length = input_ids.shape
         if attention_mask is None:
             attention_mask = jnp.ones_like(input_ids)
-        if segment_ids is None:
-            segment_ids = jnp.zeros_like(input_ids)
         if position_ids is None:
             position_ids = jnp.arange(seq_length, dtype=jnp.int32)[None].repeat(batch_size, axis=0)
         outputs = self.transformer(
2 changes: 1 addition & 1 deletion tpu_requirements.sh
@@ -20,7 +20,7 @@ optax==0.2.2
 --extra-index-url https://download.pytorch.org/whl/cpu
 torch==2.0.0
 transformers==4.40.0
-ringattention==0.1.1
+ringattention @ git+https://github.com/forhaoliu/ringattention.git
 datasets
 einops
 tqdm
