aqlm #233

Closed · wants to merge 8 commits

23 changes: 10 additions & 13 deletions .github/workflows/build.yaml
@@ -4,9 +4,10 @@ on:
workflow_dispatch:
push:
branches:
- 'main'
- "main"
- "aqlm"
tags:
- 'v*'
- "v*"

jobs:
build-and-push-image:
@@ -41,8 +42,8 @@ jobs:
- name: Install soci
uses: lerentis/[email protected]
with:
soci-release: 'v0.4.0'
soci-release: "v0.4.0"

- name: Set up Docker Buildx
uses: docker/[email protected]

@@ -51,7 +52,7 @@ jobs:
with:
config-inline: |
version = 2

# persistent data location
root = "/var/lib/kubelet/containerd"

@@ -62,11 +63,8 @@ jobs:
images: |
ghcr.io/predibase/lorax
tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=sha,prefix=,suffix=,format=short
type=raw,value=latest

type=raw,value=aqlm

- name: Create a hash from tags
env:
tags: ${{ steps.meta.outputs.tags }}
@@ -93,7 +91,7 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
file: ./Dockerfile # Path to your Dockerfile
file: ./Dockerfile # Path to your Dockerfile
push: false
tags: ${{ steps.meta.outputs.tags }}
outputs: type=oci,compression=gzip,dest=${{ steps.vars.outputs.image_path }}-${{ steps.vars.outputs.tag_hash }}.tar.gz
@@ -124,7 +122,7 @@ jobs:
echo "Pushing $tag to GHCR"
sudo ctr i push --user "${{ github.repository_owner }}:${{ secrets.GHCR_PAT }}" $tag
done

- name: Create and push soci index
env:
tags: ${{ steps.meta.outputs.tags }}
@@ -151,4 +149,3 @@ jobs:

# Delete the SHA image(s) from containerd store
sudo ctr i rm $(sudo ctr i ls -q)

4 changes: 4 additions & 0 deletions launcher/src/main.rs
@@ -30,6 +30,7 @@ enum Quantization {
Hqq_4bit,
Hqq_3bit,
Hqq_2bit,
Aqlm,
}

impl std::fmt::Display for Quantization {
@@ -63,6 +64,9 @@ impl std::fmt::Display for Quantization {
Quantization::Hqq_2bit => {
write!(f, "hqq-2bit")
}
Quantization::Aqlm => {
write!(f, "aqlm")
}
}
}
}
1 change: 1 addition & 0 deletions server/lorax_server/cli.py
@@ -21,6 +21,7 @@ class Quantization(str, Enum):
hqq_4bit = "hqq-4bit"
hqq_3bit = "hqq-3bit"
hqq_2bit = "hqq-2bit"
aqlm = "aqlm"


class Dtype(str, Enum):
6 changes: 3 additions & 3 deletions server/lorax_server/models/__init__.py
@@ -234,7 +234,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)

if model_type == "llama":
if model_type in ["llama", "llama_aqlm"]:

Review comment: llama_aqlm and similar custom model types are deprecated. We moved AQLM from custom code to the transformers integration in transformers 4.38.0. Correct me if I'm wrong and it serves a different purpose here.
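
For reference, a minimal sketch of the transformers-native path this comment refers to, assuming transformers >= 4.38 with the aqlm package installed; the checkpoint id below is only an illustrative public AQLM model, not something lorax pins:

    # Sketch: transformers>=4.38 reads the AQLM settings from the checkpoint's
    # quantization config and dispatches to the aqlm kernels automatically.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"  # example AQLM checkpoint
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    inputs = tokenizer("Hello, AQLM!", return_tensors="pt").to(model.device)
    print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))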

if FLASH_ATTENTION:
return FlashLlama(
model_id,
@@ -306,7 +306,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)

if model_type == "mistral":
if model_type in ["mistral", "mistral_aqlm"]:
if MISTRAL:
return FlashMistral(
model_id,
@@ -320,7 +320,7 @@
)
raise NotImplementedError("Mistral model requires flash attention v2")

if model_type == "mixtral":
if model_type in ["mixtral", "mixtral_aqlm"]:
if MIXTRAL:
return FlashMixtral(
model_id,
2 changes: 2 additions & 0 deletions server/lorax_server/models/flash_llama.py
@@ -91,6 +91,8 @@ def __init__(

if config.quantize in ["gptq", "awq", "eetq"]:
weights._set_gptq_params(model_id)
elif config.quantize == "aqlm":
weights._set_aqlm_params(model_id)

model = FlashLlamaForCausalLM(config, weights)

2 changes: 2 additions & 0 deletions server/lorax_server/models/flash_mistral.py
@@ -376,6 +376,8 @@ def __init__(

if config.quantize in ["gptq", "awq", "eetq"]:
weights._set_gptq_params(model_id)
elif config.quantize == "aqlm":
weights._set_aqlm_params(model_id)

model = FlashMistralForCausalLM(config, weights)

2 changes: 2 additions & 0 deletions server/lorax_server/models/flash_mixtral.py
@@ -383,6 +383,8 @@ def __init__(

if config.quantize in ["gptq", "awq", "eetq"]:
weights._set_gptq_params(model_id)
elif config.quantize == "aqlm":
weights._set_aqlm_params(model_id)

model = FlashMixtralForCausalLM(config, weights)

15 changes: 15 additions & 0 deletions server/lorax_server/utils/layers.py
@@ -40,6 +40,12 @@ def weight(self) -> torch.Tensor:
except ImportError:
HAS_HQQ = False

HAS_AQLM = True
try:
from aqlm import QuantizedLinear
except ImportError:
HAS_AQLM = False

from accelerate import init_empty_weights

from lorax_server.utils.gptq.quant_linear import QuantLinear
@@ -385,6 +391,15 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False):
layer.bias.data = bias

linear = HQQLinearLayer(layer, quant_config, del_orig=True)
elif quantize == "aqlm":
scales, codebooks, codes, nbits_per_codebook, num_codebooks, out_group_size, in_group_size = weight
linear = QuantizedLinear(scales.shape[1], scales.shape[0], in_group_size, out_group_size, num_codebooks, nbits_per_codebook)
with torch.no_grad():
linear.scales = scales
linear.codebooks = codebooks
linear.codes = codes
if bias is not None:
linear.bias.data = bias
else:
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
return linear
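
For context, a sketch of what the new aqlm branch in get_linear constructs before the checkpoint tensors are copied in. It assumes the aqlm package is installed; the dimensions and group sizes are illustrative placeholders, not values lorax fixes, and the positional argument order mirrors the call in the diff:

    import torch
    from aqlm import QuantizedLinear  # the same import that HAS_AQLM guards above

    # Placeholder shapes for a common 1x16 AQLM layout: one codebook with
    # 2**16 entries, in_group_size=8, out_group_size=1.
    in_features, out_features = 4096, 4096
    linear = QuantizedLinear(in_features, out_features, 8, 1, 1, 16)

    # get_linear then overwrites these freshly initialized parameters with the
    # sharded scales/codebooks/codes loaded from the checkpoint.
    print(linear.scales.shape, linear.codebooks.shape, linear.codes.shape)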
28 changes: 28 additions & 0 deletions server/lorax_server/utils/weights.py
@@ -204,6 +204,12 @@ def get_multi_weights_col(self, prefixes: List[Union[str, Tuple]], quantize: str

bits, groupsize = self._get_gptq_params()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
elif quantize == "aqlm":
nbits_per_codebook, num_codebooks, out_group_size, in_group_size = self._get_aqlm_params()
scales = self.get_sharded_list("scales", prefixes, dim=0)
codebooks = self.get_sharded_list("codebooks", prefixes, dim=0)
codes = self.get_sharded_list("codes", prefixes, dim=0)
weight = (scales, codebooks, codes, nbits_per_codebook, num_codebooks, out_group_size, in_group_size)
else:
w = self.get_sharded_list("weight", prefixes, dim=0)
weight = torch.cat(w, dim=dim)
@@ -290,6 +296,12 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
g_idx = None
use_exllama = False
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
elif quantize == "aqlm":
nbits_per_codebook, num_codebooks, out_group_size, in_group_size = self._get_aqlm_params()
scales = self.get_sharded(f"{prefix}.scales", dim=1)
codebooks = self.get_sharded(f"{prefix}.codebooks", dim=1)
codes = self.get_sharded(f"{prefix}.codes", dim=1)
weight = (scales, codebooks, codes, nbits_per_codebook, num_codebooks, out_group_size, in_group_size)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight
@@ -342,6 +354,22 @@ def _set_gptq_params(self, model_id):
self.gptq_groupsize = data["q_group_size"]
except Exception:
pass

def _get_aqlm_params(self) -> Tuple[int, int, int, int]:
return self.nbits_per_codebook, self.num_codebooks, self.out_group_size, self.in_group_size

def _set_aqlm_params(self, model_id):
filename = "config.json"
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(model_id, filename=filename)
with open(filename, "r") as f:
data = json.load(f)
self.nbits_per_codebook = data["aqlm"]["nbits_per_codebook"]
self.num_codebooks = data["aqlm"]["num_codebooks"]
self.out_group_size = data["aqlm"]["out_group_size"]
self.in_group_size = data["aqlm"]["in_group_size"]

def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
block_size = size // world_size
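
For context, an illustrative sketch of the config.json shape that _set_aqlm_params reads. The key names come from the diff; the values are placeholders, and published AQLM checkpoints may instead carry these fields under transformers' quantization_config:

    import json

    # Hypothetical config.json contents matching the keys read by _set_aqlm_params.
    example_config = {
        "model_type": "llama",
        "aqlm": {
            "nbits_per_codebook": 16,
            "num_codebooks": 1,
            "out_group_size": 1,
            "in_group_size": 8,
        },
    }
    print(json.dumps(example_config, indent=2))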
3 changes: 2 additions & 1 deletion server/pyproject.toml
@@ -38,13 +38,14 @@ boto3 = "^1.28.34"
urllib3 = "<=1.26.18"
hqq = { version = "^0.1.2", optional = true }
stanford-stk = { version = "^0.7.0", markers = "sys_platform == 'linux'" }
aqlm = { version = "^1.0.0"}

Review comment: Proper backprop was added in aqlm==1.1.0, and bf16 support was added in aqlm==1.1.2.
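
A minimal sketch of a runtime guard reflecting the versions noted in this comment; the minimums come from the comment itself, not from lorax, and bumping the constraint above to aqlm = { version = "^1.1.2" } would achieve the same thing at install time:

    from importlib.metadata import version

    from packaging.version import Version

    installed = Version(version("aqlm"))
    if installed < Version("1.1.0"):
        raise RuntimeError("aqlm>=1.1.0 is required for proper backprop support")
    if installed < Version("1.1.2"):
        raise RuntimeError("aqlm>=1.1.2 is required for bf16 support")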


[tool.poetry.extras]
torch = ["torch"]
accelerate = ["accelerate"]
bnb = ["bitsandbytes"]
peft = ["peft"]
quantize = ["texttable", "datasets", "accelerate", "hqq"]
quantize = ["texttable", "datasets", "accelerate", "hqq", "aqlm"]

[tool.poetry.group.dev.dependencies]
grpcio-tools = "^1.51.1"