Fix deps, enable 8-bit by default for TP (#298)
This PR fixes the issues reported in #290:

- The hivemind bfloat16 codec crashed on dummy tensors (tensors with 0 elements); see learning-at-home/hivemind#560. As a temporary measure, this PR makes Petals depend on the latest hivemind version from the repo. A minimal repro sketch follows this list.
- The transformers version check did not match the version range allowed in `setup.cfg`.

Also:

- This PR enables 8-bit by default for TP. Even though TP in 8-bit may be slower, we currently prefer to host more blocks to increase the network's stability.
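
For context, here is a minimal sketch of the crash scenario from the first bullet (illustrative only, not part of this PR's diff). It assumes hivemind's public compression helpers `serialize_torch_tensor`/`deserialize_torch_tensor` and its `CompressionType` enum, whose bfloat16 member is spelled `BLFOAT16` in hivemind:

```python
import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto.runtime_pb2 import CompressionType

# A dummy tensor with 0 elements, the case described above
dummy = torch.empty(0, dtype=torch.float32)

# Before learning-at-home/hivemind#560, this round-trip crashed inside
# the bfloat16 codec; with the fix it returns the empty tensor intact.
serialized = serialize_torch_tensor(dummy, CompressionType.BLFOAT16)
restored = deserialize_torch_tensor(serialized)
assert restored.numel() == 0
```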
borzunov authored Mar 29, 2023
1 parent 987f4d2 commit 2116df0
Showing 3 changed files with 3 additions and 9 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -37,7 +37,7 @@ install_requires =
     huggingface-hub==0.11.1
     transformers>=4.25.1,<5.0.0
     speedtest-cli==2.1.3
-    hivemind==1.1.6
+    hivemind @ git+https://github.com/learning-at-home/hivemind.git
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2
4 changes: 2 additions & 2 deletions src/petals/bloom/block.py
@@ -13,8 +13,8 @@
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.26.0") < version.parse(transformers.__version__) < version.parse("5.0.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.26.0,<5.0.0"
+        version.parse("4.25.1") <= version.parse(transformers.__version__) < version.parse("5.0.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.25.1,<5.0.0"
 
 
 class WrappedBloomBlock(BloomBlock):
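
The mismatch is easy to see with `packaging.version`: the old assert used a strict `<` against 4.26.0, so the minimum version that `setup.cfg` permits failed the check. A small illustration (not part of the diff):

```python
from packaging import version

installed = version.parse("4.25.1")  # the minimum that setup.cfg allows

# Old assert: required a version strictly greater than 4.26.0
old_check = version.parse("4.26.0") < installed < version.parse("5.0.0")

# New assert: agrees with the transformers>=4.25.1,<5.0.0 range in setup.cfg
new_check = version.parse("4.25.1") <= installed < version.parse("5.0.0")

print(old_check, new_check)  # False True
```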
6 changes: 0 additions & 6 deletions src/petals/server/server.py
@@ -163,12 +163,6 @@ def __init__(
 
         if load_in_8bit is None:
             load_in_8bit = device.type == "cuda"
-        if load_in_8bit and len(self.tensor_parallel_devices) > 1:
-            load_in_8bit = False
-            logger.warning(
-                "Tensor parallelism doesn't work properly with 8-bit weights yet, loading weights in 16-bit. "
-                "You can explicitly set `--load_in_8bit True` to override this"
-            )
         self.load_in_8bit = load_in_8bit
         logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
 
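
With the override removed, the default reduces to the context lines of the diff above: 8-bit is now chosen whenever the device is CUDA, even with multiple tensor-parallel devices. Excerpted for clarity (no new logic):

```python
# Resulting default after this PR, excerpted from the diff context above:
# a CUDA server now loads 8-bit weights even when tensor parallelism
# spans several devices; an explicit `--load_in_8bit False` (the flag
# mentioned in the removed warning) still opts out.
if load_in_8bit is None:
    load_in_8bit = device.type == "cuda"
self.load_in_8bit = load_in_8bit
```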
