Skip to content

Commit

Permalink
Add validation for tensor_split size exceeding LLAMA_MAX_DEVICES (ggml-org#820)
Browse files (browse the repository at this point in the history)

* Add validation for tensor_split size exceeding LLAMA_MAX_DEVICES

* reword
  • Loading branch information
eric1932 authored Oct 15, 2023
1 parent f30aa20 commit b501665
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,8 @@ def __init__(
self.tensor_split = tensor_split
self._p_tensor_split = None
if self.tensor_split is not None:
if len(self.tensor_split) > llama_cpp.LLAMA_MAX_DEVICES:
raise ValueError(f"Attempt to split tensors that exceed maximum supported devices. Current LLAMA_MAX_DEVICES={llama_cpp.LLAMA_MAX_DEVICES}")
# Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES
self._c_tensor_split = FloatArray(
Expand Down

0 comments on commit b501665

Please sign in to comment.