-
-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Neuron] Adding support for adding/ overriding neuron configuration a…
…nd adding support for neuron model quantization configuration.
- Loading branch information
Harsha Bikki
committed
Sep 4, 2024
1 parent
d331156
commit f4ac863
Showing
8 changed files
with
243 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import os | ||
|
||
from vllm import LLM, SamplingParams | ||
|
||
# creates XLA hlo graphs for all the context length buckets. | ||
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048" | ||
# creates XLA hlo graphs for all the token gen buckets. | ||
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048" | ||
# Quantizes neuron model weight to int8 , | ||
# The default config for quantization is int8 dtype. | ||
os.environ['NEURON_QUANT_DTYPE'] = "s8" | ||
|
||
# Sample prompts. | ||
prompts = [ | ||
"Hello, my name is", | ||
"The president of the United States is", | ||
"The capital of France is", | ||
"The future of AI is", | ||
] | ||
# Create a sampling params object. | ||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95) | ||
|
||
# Create an LLM. | ||
llm = LLM( | ||
model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", | ||
max_num_seqs=8, | ||
# The max_model_len and block_size arguments are required to be same as | ||
# max sequence length when targeting neuron device. | ||
# Currently, this is a known limitation in continuous batching support | ||
# in transformers-neuronx. | ||
# TODO(liangfu): Support paged-attention in transformers-neuronx. | ||
max_model_len=2048, | ||
block_size=2048, | ||
# The device can be automatically detected when AWS Neuron SDK is installed. | ||
# The device argument can be either unspecified for automated detection, | ||
# or explicitly assigned. | ||
device="neuron", | ||
quantization="neuron_quant", | ||
override_neuron_config={ | ||
"cast_logits_dtype": "bfloat16", | ||
}, | ||
tensor_parallel_size=2) | ||
# Generate texts from the prompts. The output is a list of RequestOutput objects | ||
# that contain the prompt, generated text, and other information. | ||
outputs = llm.generate(prompts, sampling_params) | ||
# Print the outputs. | ||
for output in outputs: | ||
prompt = output.prompt | ||
generated_text = output.outputs[0].text | ||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import os | ||
from importlib.util import find_spec | ||
from typing import Any, Dict, List, Optional | ||
|
||
from torch.nn import Module | ||
|
||
from vllm.model_executor.layers.quantization.base_config import ( | ||
QuantizationConfig) | ||
|
||
SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn'] | ||
|
||
|
||
class NeuronQuantConfig(QuantizationConfig): | ||
"""Int8 Quantization Config class for Neuron Backend.""" | ||
|
||
def __init__( | ||
self, | ||
dequant_dtype: str = "f16", | ||
quantize_method: str = "vector_dynamic", | ||
) -> None: | ||
self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8") | ||
if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST: | ||
raise ValueError( | ||
f"Neuron quantization datatype {self.quant_dtype} is not valid," | ||
f"the quantization datatype should match one of the below types" | ||
f"{SUPPORTED_QUANT_DTYPE_LIST}") | ||
self.dequant_dtype = dequant_dtype | ||
self.quantize_method = quantize_method | ||
|
||
def get_name(self) -> str: | ||
return "neuron_quant" | ||
|
||
def get_supported_act_dtypes(self) -> List[str]: | ||
return SUPPORTED_QUANT_DTYPE_LIST | ||
|
||
@classmethod | ||
def get_min_capability(cls) -> int: | ||
raise NotImplementedError( | ||
"This function should not be called with Neuron Backend") | ||
|
||
@staticmethod | ||
def get_config_filenames() -> List[str]: | ||
return [] | ||
|
||
@classmethod | ||
def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig": | ||
quantize_method = cls.get_from_keys(config, ["quantize_method"]) | ||
dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"]) | ||
return cls(dequant_dtype=dequant_dtype, | ||
quantize_method=quantize_method) | ||
|
||
def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]: | ||
if find_spec("transformers_neuronx") is not None: | ||
return self.get_quantization_config() | ||
else: | ||
raise NotImplementedError( | ||
"Neuron Quantization is only supported through" | ||
" transformers_neuronx.") | ||
|
||
def get_scaled_act_names(self) -> List[str]: | ||
return [] | ||
|
||
def get_quantization_config(self): | ||
from transformers_neuronx.config import QuantizationConfig | ||
return QuantizationConfig(quant_dtype=self.quant_dtype, | ||
dequant_dtype=self.dequant_dtype, | ||
quantize_method=self.quantize_method) |
Oops, something went wrong.