Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Prototype of vLLM execution on CPU-only devices #1028

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions benchmarks/benchmark_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def main(args: argparse.Namespace):
tensor_parallel_size=args.tensor_parallel_size,
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
device=args.device,
enforce_eager=args.enforce_eager,
swap_space=args.swap_space,
)

sampling_params = SamplingParams(
Expand Down Expand Up @@ -127,5 +129,15 @@ def run_to_completion(profile_dir: Optional[str] = None):
'path to save the pytorch profiler output. Can be visualized '
'with ui.perfetto.dev or Tensorboard.'
))
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda", "cpu"],
help='device type for vLLM execution, supporting CUDA and CPU.')
parser.add_argument("--swap-space",
type=int,
default=4,
help="memory space available for CPU (GB).")
args = parser.parse_args()
main(args)
30 changes: 20 additions & 10 deletions benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,10 @@ def run_vllm(
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
device: str,
max_model_len: Optional[int],
enforce_eager: bool,
swap_space: int,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
Expand All @@ -83,6 +85,8 @@ def run_vllm(
dtype=dtype,
max_model_len=max_model_len,
enforce_eager=enforce_eager,
device=device,
swap_space=swap_space,
)

# Add the requests to the engine.
Expand All @@ -109,21 +113,16 @@ def run_vllm(
return end - start


def run_hf(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
use_beam_search: bool,
max_batch_size: int,
trust_remote_code: bool,
) -> float:
def run_hf(requests: List[Tuple[str, int, int]], model: str,
tokenizer: PreTrainedTokenizerBase, n: int, use_beam_search: bool,
max_batch_size: int, trust_remote_code: bool) -> float:
assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama":
# To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token

llm = llm.cuda()

pbar = tqdm(total=len(requests))
Expand Down Expand Up @@ -206,7 +205,8 @@ def main(args: argparse.Namespace):
args.quantization, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager)
args.device, args.max_model_len, args.enforce_eager,
args.swap_space)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
Expand Down Expand Up @@ -284,6 +284,16 @@ def main(args: argparse.Namespace):
parser.add_argument("--enforce-eager",
action="store_true",
help="enforce eager execution")
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda", "cpu"],
help='device type for vLLM execution, supporting CUDA and CPU.')
parser.add_argument("--swap-space",
type=int,
default=4,
help="memory space available for CPU (GB).")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
Expand Down
77 changes: 77 additions & 0 deletions cpu.Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
FROM python:3.10 AS dev

RUN apt-get update -y \
&& apt-get install -y python3-pip

WORKDIR /workspace

# install build and runtime dependencies
COPY requirements-cpu.txt requirements-cpu.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements-cpu.txt

# install development dependencies
COPY requirements-dev.txt requirements-dev.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements-dev.txt

# image to build pytorch extensions
FROM dev AS build

# install build dependencies
COPY requirements-build-cpu.txt requirements-build-cpu.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements-build-cpu.txt

# copy input files
COPY csrc csrc
COPY setup.py setup.py
COPY requirements-cpu.txt requirements-cpu.txt
COPY pyproject.toml pyproject.toml
COPY vllm/__init__.py vllm/__init__.py

# max jobs used by Ninja to build extensions
ENV MAX_JOBS=$max_jobs
RUN python3 setup.py build_ext --inplace

# image to run unit testing suite
FROM dev AS test

# copy pytorch extensions separately to avoid having to rebuild
# when python code changes
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY tests tests
COPY vllm vllm

ENTRYPOINT ["python3", "-m", "pytest", "tests"]

# use CUDA base as CUDA runtime dependencies are already installed via pip
FROM python:3.10 AS dev

# libnccl required for ray
RUN apt-get update -y \
&& apt-get install -y python3-pip

WORKDIR /workspace
COPY requirements-cpu.txt requirements-cpu.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements-cpu.txt

FROM vllm-base AS vllm
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vllm-base is not available if build cpu.Dockerfile only

COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY vllm vllm

EXPOSE 8000
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"]

# openai api server alternative
FROM vllm-base AS vllm-openai
# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
pip install accelerate fschat

COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY vllm vllm

ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]

2 changes: 1 addition & 1 deletion csrc/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ void gather_cached_kv(
torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping);
torch::Tensor& slot_mapping);
54 changes: 54 additions & 0 deletions csrc/cpu/activation_impl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include "cpu_types.hpp"

namespace {
template <typename scalar_t>
void silu_and_mul_cpu_impl(int num_tokens, int d, scalar_t *__restrict__ input,
scalar_t *__restrict__ output) {
using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();

TORCH_CHECK(d % VEC_ELEM_NUM == 0);

const vec_op::FP32Vec8 zeros(0.0);
const vec_op::FP32Vec8 ones(1.0);

#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
for (int j = 0; j < d; j += VEC_ELEM_NUM) {
const int start = i * 2 * d;
const scalar_vec_t x(input + start + j);
const scalar_vec_t y(input + start + d + j);

const vec_op::FP32Vec8 f32_x(x.reg);
const vec_op::FP32Vec8 f32_y(y.reg);

const vec_op::FP32Vec8 f32_ans =
f32_y * (f32_x / (ones + (zeros - f32_x).exp()));

const scalar_vec_t ans(f32_ans.reg);
ans.save(output + i * d + j);
}
}
}
}; // namespace

void silu_and_mul_cpu(torch::Tensor &out, torch::Tensor &input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;

VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "silu_and_mul_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(silu_and_mul_cpu_impl)
silu_and_mul_cpu_impl(num_tokens, d, input.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(silu_and_mul_cpu_impl)
});
}

void gelu_new_cpu(torch::Tensor &out, torch::Tensor &input) {
TORCH_CHECK(false, "gelu_new is unsupported on CPU.")
}

void gelu_fast_cpu(torch::Tensor &out, torch::Tensor &input) {
TORCH_CHECK(false, "gelu_fast is unsupported on CPU.")
}
Loading