diff --git a/docs/source/use-vllm-as-backend.mdx b/docs/source/use-vllm-as-backend.mdx
index 9d4bb863..db1f503c 100644
--- a/docs/source/use-vllm-as-backend.mdx
+++ b/docs/source/use-vllm-as-backend.mdx
@@ -87,3 +87,115 @@
 An optional key `metric_options` can be passed in the yaml file, using the name of the metric or metrics, as defined in the `Metric.metric_name`.
 In this case, the `codegen_pass@1:16` metric defined in our tasks will have the `num_samples` updated to 16,
 independently of the number defined by default.
+
+
+## Multi-node vLLM
+
+It is entirely possible to use vLLM in a multi-node setting. For this, we will use Ray.
+In these examples, we assume that we are on a Slurm cluster whose nodes do not have internet access.
+These scripts are heavily inspired by [https://github.com/NERSC/slurm-ray-cluster/](https://github.com/NERSC/slurm-ray-cluster/).
+
+First, you need to start the Ray cluster. It has one master node, and you need to find an available
+port on that node:
+
+```bash
+function find_available_port {
+    printf $(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1]); s.close()')
+}
+
+PORT=$(find_available_port)
+NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
+MASTER=${NODELIST[0]}               # Name of the master node
+MASTER_IP=$(hostname --ip-address)  # IP address of the master node
+
+function set_VLLM_HOST_IP {
+    export VLLM_HOST_IP=$(hostname --ip-address);
+}
+export -f set_VLLM_HOST_IP;
+
+# Start the master
+srun -N1 -n1 -c $(( SLURM_CPUS_PER_TASK/2 )) -w $MASTER bash -c "set_VLLM_HOST_IP; ray start --head --port=$PORT --block" &
+sleep 5
+
+# Start all the other nodes
+if [[ $SLURM_NNODES -gt 1 ]]; then
+    srun -N $(( SLURM_NNODES-1 )) --ntasks-per-node=1 -c $(( SLURM_CPUS_PER_TASK/2 )) -x $MASTER bash -c "set_VLLM_HOST_IP; ray start --address=$MASTER_IP:$PORT --block" &
+    sleep 5
+fi
+```
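+
+As an optional sanity check, you can confirm that every node joined the cluster before going further. This is only a sketch, not part of the original scripts, and it assumes the `ray` CLI is available in the same environment:
+
+```bash
+# The summary should list the master (head) node plus the workers, and every GPU of the
+# allocation (2 nodes x 4 GPUs in the example scripts).
+ray status --address=$MASTER_IP:$PORT
+```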
+
+Then, once the Ray cluster is running, you can launch vLLM through lighteval:
+
+```bash
+set_VLLM_HOST_IP
+
+MODEL_ARGS="pretrained=$MODEL_DIRECTORY,gpu_memory_utilisation=0.5,trust_remote_code=False,dtype=bfloat16,max_model_length=16384,tensor_parallel_size=8"
+TASK_ARGS="community|gpqa-fr|0|0,community|ifeval-fr|0|0"
+
+srun -N1 -n1 -c $(( SLURM_CPUS_PER_TASK/2 )) --overlap -w $MASTER lighteval vllm "$MODEL_ARGS" "$TASK_ARGS" --custom-tasks $TASK_FILE
+```
+
+The full script is available here: [https://github.com/huggingface/lighteval/blob/main/examples/slurm/multi_node_vllm.slurm](https://github.com/huggingface/lighteval/blob/main/examples/slurm/multi_node_vllm.slurm)
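+
+Note that `tensor_parallel_size` is set to the total number of GPUs in the allocation (2 nodes x 4 GPUs = 8 in the full script). If you prefer not to hardcode it, the following sketch derives it from the Slurm environment; `GPUS_PER_NODE` is an assumption matching the `--gres=gpu:4` request and is not exported by the scripts:
+
+```bash
+# Derive the tensor parallel size from the allocation instead of hardcoding it.
+GPUS_PER_NODE=4   # assumption: matches the --gres=gpu:4 request
+TP_SIZE=$(( SLURM_NNODES * GPUS_PER_NODE ))
+MODEL_ARGS="pretrained=$MODEL_DIRECTORY,gpu_memory_utilisation=0.5,dtype=bfloat16,max_model_length=16384,tensor_parallel_size=$TP_SIZE"
+```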
+
+## With vLLM Serve
+
+It is also possible to use the `vllm serve` command to achieve a similar result. It has the following benefits:
+the server can be queried by multiple jobs, it only needs to be launched once when running multiple evaluations,
+and it has a lower peak memory usage on rank 0.
+
+We also need to start the Ray cluster, in exactly the same way as before. Now, however,
+before calling lighteval, we need to start our vLLM server.
+
+```bash
+MODEL_NAME="Llama-3.2-1B-Instruct"
+SERVER_PORT=$(find_available_port)
+export OPENAI_API_KEY="I-love-vLLM"
+export OPENAI_BASE_URL="http://localhost:$SERVER_PORT/v1"
+
+vllm serve $MODEL_DIRECTORY \
+    --served-model-name $MODEL_NAME \
+    --api-key $OPENAI_API_KEY \
+    --enforce-eager \
+    --port $SERVER_PORT \
+    --tensor-parallel-size 8 \
+    --dtype bfloat16 \
+    --max-model-len 16384 \
+    --gpu-memory-utilization 0.8 \
+    --disable-custom-all-reduce \
+    1>vllm.stdout 2>vllm.stderr &
+```
+
+In this example, the evaluation runs within the same job, which is why `vllm serve` is put in the
+background. We therefore need to wait until it is up and running before launching lighteval:
+
+```bash
+ATTEMPT=0
+DELAY=5
+MAX_ATTEMPTS=60  # May need to be increased for very large models
+until curl -s -o /dev/null -w "%{http_code}" $OPENAI_BASE_URL/models -H "Authorization: Bearer $OPENAI_API_KEY" | grep -E "^2[0-9]{2}$"; do
+    ATTEMPT=$((ATTEMPT + 1))
+    echo "$ATTEMPT attempts"
+    if [ "$ATTEMPT" -ge "$MAX_ATTEMPTS" ]; then
+        echo "Failed: the server did not respond to any of the $MAX_ATTEMPTS requests."
+        exit 1
+    fi
+    echo "vllm serve is not ready yet"
+    sleep $DELAY
+done
+```
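+
+Once the loop exits, the server answers on the `/models` route. If you want an extra smoke test before launching the full evaluation, you can send a single request through the standard OpenAI-compatible chat completions route exposed by `vllm serve`; this is optional and not part of the scripts:
+
+```bash
+# One-off request against the served model; a short JSON answer confirms end-to-end generation works.
+curl -s "$OPENAI_BASE_URL/chat/completions" \
+    -H "Authorization: Bearer $OPENAI_API_KEY" \
+    -H "Content-Type: application/json" \
+    -d '{"model": "'"$MODEL_NAME"'", "messages": [{"role": "user", "content": "Say hi"}], "max_tokens": 16}'
+```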
+
+Finally, the above loop only exits once `vllm serve` is ready, so we can then launch the evaluation:
+
+```bash
+export TOKENIZER_PATH="$MODEL_DIRECTORY"
+
+TASK_ARGS="community|gpqa-fr|0|0,community|ifeval-fr|0|0"
+
+lighteval endpoint openai "$MODEL_NAME" "$TASK_ARGS" --custom-tasks $TASK_FILE
+```
+
+The full script is available here:
+[https://github.com/huggingface/lighteval/blob/main/examples/slurm/multi_node_vllm_serve.slurm](https://github.com/huggingface/lighteval/blob/main/examples/slurm/multi_node_vllm_serve.slurm)
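+
+Because `vllm serve` runs in the background of the same job, it is simply killed when the job ends. If you would rather shut it down explicitly once the evaluation has finished (for instance to run several evaluations in one job), you can keep track of its PID. This is a minimal sketch, not something the example scripts do:
+
+```bash
+vllm serve $MODEL_DIRECTORY --served-model-name $MODEL_NAME --api-key $OPENAI_API_KEY \
+    --port $SERVER_PORT 1>vllm.stdout 2>vllm.stderr &
+VLLM_PID=$!                         # PID of the background server
+
+# ... wait for readiness and run lighteval as shown above ...
+
+kill $VLLM_PID                      # stop the server explicitly
+wait $VLLM_PID 2>/dev/null || true  # reap it; ignore the non-zero exit status under `set -e`
+```
+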
diff --git a/examples/slurm/multi_node_vllm.slurm b/examples/slurm/multi_node_vllm.slurm
new file mode 100644
index 00000000..ddc19393
--- /dev/null
+++ b/examples/slurm/multi_node_vllm.slurm
@@ -0,0 +1,57 @@
+#!/bin/bash
+
+#SBATCH --job-name=EVALUATE_Llama-3.2-1B-Instruct
+#SBATCH --account=brb@h100
+#SBATCH --output=evaluation.log
+#SBATCH --error=evaluation.log
+#SBATCH --nodes=2
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=96
+#SBATCH --gres=gpu:4
+#SBATCH --hint=nomultithread
+#SBATCH --constraint=h100
+#SBATCH --time=02:00:00
+#SBATCH --exclusive
+#SBATCH --parsable
+
+set -e
+set -x
+
+module purge
+module load miniforge/24.9.0
+conda activate $WORKDIR/lighteval-h100
+
+function find_available_port {
+    printf $(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1]); s.close()')
+}
+
+PORT=$(find_available_port)
+NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
+MASTER=${NODELIST[0]}
+MASTER_IP=$(hostname --ip-address)
+
+export HF_HOME=$WORKDIR/HF_HOME
+export MODEL_DIRECTORY=$WORKDIR/HuggingFace_Models/meta-llama/Llama-3.2-1B-Instruct
+export TASK_FILE=$WORKDIR/community_tasks/french_eval.py
+export HF_HUB_OFFLINE=1
+export VLLM_WORKER_MULTIPROC_METHOD=spawn
+
+function set_VLLM_HOST_IP {
+    export VLLM_HOST_IP=$(hostname --ip-address);
+}
+export -f set_VLLM_HOST_IP;
+
+srun -N1 -n1 -c $(( SLURM_CPUS_PER_TASK/2 )) -w $MASTER bash -c "set_VLLM_HOST_IP; ray start --head --port=$PORT --block" &
+sleep 5
+
+if [[ $SLURM_NNODES -gt 1 ]]; then
+    srun -N $(( SLURM_NNODES-1 )) --ntasks-per-node=1 -c $(( SLURM_CPUS_PER_TASK/2 )) -x $MASTER bash -c "set_VLLM_HOST_IP; ray start --address=$MASTER_IP:$PORT --block" &
+    sleep 5
+fi
+
+set_VLLM_HOST_IP
+
+MODEL_ARGS="pretrained=$MODEL_DIRECTORY,gpu_memory_utilisation=0.5,trust_remote_code=False,dtype=bfloat16,max_model_length=8192,tensor_parallel_size=8"
+TASK_ARGS="community|gpqa-fr|0|0,community|ifeval-fr|0|0"
+
+srun -N1 -n1 -c $(( SLURM_CPUS_PER_TASK/2 )) --overlap -w $MASTER lighteval vllm "$MODEL_ARGS" "$TASK_ARGS" --custom-tasks $TASK_FILE
diff --git a/examples/slurm/multi_node_vllm_serve.slurm b/examples/slurm/multi_node_vllm_serve.slurm
new file mode 100644
index 00000000..2b4a12eb
--- /dev/null
+++ b/examples/slurm/multi_node_vllm_serve.slurm
@@ -0,0 +1,90 @@
+#!/bin/bash
+
+#SBATCH --job-name=EVALUATE_Llama-3.2-1B-Instruct
+#SBATCH --account=brb@h100
+#SBATCH --output=evaluation.log
+#SBATCH --error=evaluation.log
+#SBATCH --nodes=2
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=96
+#SBATCH --gres=gpu:4
+#SBATCH --hint=nomultithread
+#SBATCH --constraint=h100
+#SBATCH --time=02:00:00
+#SBATCH --exclusive
+#SBATCH --parsable
+
+set -e
+set -x
+
+module purge
+module load miniforge/24.9.0
+conda activate $WORKDIR/lighteval-h100
+
+function find_available_port {
+    printf $(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1]); s.close()')
+}
+
+PORT=$(find_available_port)
+NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
+MASTER=${NODELIST[0]}
+MASTER_IP=$(hostname --ip-address)
+
+export HF_HOME=$WORKDIR/HF_HOME
+export MODEL_DIRECTORY=$WORKDIR/HuggingFace_Models/meta-llama/Llama-3.2-1B-Instruct
+export TASK_FILE=$WORKDIR/community_tasks/french_eval.py
+export HF_HUB_OFFLINE=1
+export VLLM_WORKER_MULTIPROC_METHOD=spawn
+
+function set_VLLM_HOST_IP {
+    export VLLM_HOST_IP=$(hostname --ip-address);
+}
+export -f set_VLLM_HOST_IP;
+
+srun -N1 -n1 -c $(( SLURM_CPUS_PER_TASK/2 )) -w $MASTER bash -c "set_VLLM_HOST_IP; ray start --head --port=$PORT --block" &
+sleep 5
+
+if [[ $SLURM_NNODES -gt 1 ]]; then
+    srun -N $(( SLURM_NNODES-1 )) --ntasks-per-node=1 -c $(( SLURM_CPUS_PER_TASK/2 )) -x $MASTER bash -c "set_VLLM_HOST_IP; ray start --address=$MASTER_IP:$PORT --block" &
+    sleep 5
+fi
+
+set_VLLM_HOST_IP
+
+MODEL_NAME="Llama-3.2-1B-Instruct"
+SERVER_PORT=$(find_available_port)
+export OPENAI_API_KEY="I-love-vllm-serve"
+export OPENAI_BASE_URL="http://localhost:$SERVER_PORT/v1"
+
+vllm serve $MODEL_DIRECTORY \
+    --served-model-name $MODEL_NAME \
+    --api-key $OPENAI_API_KEY \
+    --enforce-eager \
+    --port $SERVER_PORT \
+    --tensor-parallel-size 8 \
+    --dtype bfloat16 \
+    --max-model-len 16384 \
+    --gpu-memory-utilization 0.8 \
+    --disable-custom-all-reduce \
+    1>vllm.stdout 2>vllm.stderr &
+
+
+ATTEMPT=0
+DELAY=5
+MAX_ATTEMPTS=60
+until curl -s -o /dev/null -w "%{http_code}" $OPENAI_BASE_URL/models -H "Authorization: Bearer $OPENAI_API_KEY" | grep -E "^2[0-9]{2}$"; do
+    ATTEMPT=$((ATTEMPT + 1))
+    echo "$ATTEMPT attempts"
+    if [ "$ATTEMPT" -ge "$MAX_ATTEMPTS" ]; then
+        echo "Failed: the server did not respond to any of the $MAX_ATTEMPTS requests."
+        exit 1
+    fi
+    echo "vllm serve is not ready yet"
+    sleep $DELAY
+done
+
+export TOKENIZER_PATH="$MODEL_DIRECTORY"
+
+TASK_ARGS="community|gpqa-fr|0|0,community|ifeval-fr|0|0"
+
+lighteval endpoint openai "$MODEL_NAME" "$TASK_ARGS" --custom-tasks $TASK_FILE
diff --git a/src/lighteval/models/endpoints/openai_model.py b/src/lighteval/models/endpoints/openai_model.py
index 9dd497e4..07a3456e 100644
--- a/src/lighteval/models/endpoints/openai_model.py
+++ b/src/lighteval/models/endpoints/openai_model.py
@@ -107,7 +107,8 @@ def __init__(self, config: OpenAIModelConfig, env_config) -> None:
         try:
             self._tokenizer = tiktoken.encoding_for_model(self.model)
         except KeyError:
-            self._tokenizer = AutoTokenizer.from_pretrained(self.model)
+            tokenizer_path = os.environ.get("TOKENIZER_PATH", self.model)
+            self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
         self.pairwise_tokenization = False
 
     def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, logit_bias):
diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py
index c606c04e..50ba2d6c 100644
--- a/src/lighteval/models/vllm/vllm_model.py
+++ b/src/lighteval/models/vllm/vllm_model.py
@@ -93,6 +93,7 @@ class VLLMModelConfig:
     )
     pairwise_tokenization: bool = False  # whether to tokenize the context and continuation separately or together.
     generation_parameters: GenerationParameters = None  # sampling parameters to use for generation
+    enforce_eager: bool = False  # whether or not to disable cuda graphs with vllm
 
     subfolder: Optional[str] = None
 
@@ -137,13 +138,19 @@ def tokenizer(self):
         return self._tokenizer
 
     def cleanup(self):
-        destroy_model_parallel()
+        if ray is not None:
+            ray.get(ray.remote(destroy_model_parallel).remote())
+        else:
+            destroy_model_parallel()
         if self.model is not None:
             del self.model.llm_engine.model_executor.driver_worker
         self.model = None
         gc.collect()
         ray.shutdown()
-        destroy_distributed_environment()
+        if ray is not None:
+            ray.get(ray.remote(destroy_distributed_environment).remote())
+        else:
+            destroy_distributed_environment()
         torch.cuda.empty_cache()
 
     @property
@@ -183,6 +190,7 @@ def _create_auto_model(self, config: VLLMModelConfig, env_config: EnvConfig) ->
             "max_model_len": self._max_length,
             "swap_space": 4,
             "seed": 1234,
+            "enforce_eager": config.enforce_eager,
         }
         if int(config.data_parallel_size) > 1:
             self.model_args["distributed_executor_backend"] = "ray"