Add GLM4 model #33729

Closed · wants to merge 39 commits

Commits (39)
4403382
fix converter for function definitions
ArthurZucker Sep 30, 2024
18b2c0c
Hqq serialization (#33141)
mobicham Sep 30, 2024
30400c0
Create modular_glm.py
Cyrilvallez Sep 24, 2024
5677a55
Update modular_glm.py
Cyrilvallez Sep 25, 2024
c4c9f5c
Finalize architecture without all attentions
Cyrilvallez Sep 25, 2024
6443004
Add all attentions modules
Cyrilvallez Sep 25, 2024
ff01996
Finalize modular
Cyrilvallez Sep 25, 2024
7584246
Update given last version
Cyrilvallez Sep 25, 2024
603421e
Last update
Cyrilvallez Sep 25, 2024
9e0dfee
Finalize model
Cyrilvallez Sep 26, 2024
414100d
Finalize converter
Cyrilvallez Sep 26, 2024
b816507
Update convert_glm_weights_to_hf.py
Cyrilvallez Sep 26, 2024
590321b
style
Cyrilvallez Sep 26, 2024
85cbd60
style
Cyrilvallez Sep 26, 2024
7ba2f3a
Create __init__.py
Cyrilvallez Sep 26, 2024
fd727a6
Aff all inits
Cyrilvallez Sep 26, 2024
dfa54bb
Update convert_glm_weights_to_hf.py
Cyrilvallez Sep 26, 2024
ecd5bf4
Update convert_glm_weights_to_hf.py
Cyrilvallez Sep 26, 2024
ccacc3b
Update convert_glm_weights_to_hf.py
Cyrilvallez Sep 26, 2024
bd9b9ee
Update convert_glm_weights_to_hf.py
Cyrilvallez Sep 26, 2024
bba6b12
Update convert_glm_weights_to_hf.py
Cyrilvallez Sep 26, 2024
13934a8
Update convert_glm_weights_to_hf.py
Cyrilvallez Sep 26, 2024
678062d
Update convert_glm_weights_to_hf.py
Cyrilvallez Sep 26, 2024
967944f
Update convert_glm_weights_to_hf.py
Cyrilvallez Sep 26, 2024
2588ee7
Update convert_glm_weights_to_hf.py
Cyrilvallez Sep 26, 2024
8756c10
Correct the rotary embeddings
Cyrilvallez Sep 27, 2024
e633c22
Remove apply_residual_connection_post_layernorm (always false)
Cyrilvallez Sep 27, 2024
ea3ee4e
remove use_rms_norm (always true)
Cyrilvallez Sep 27, 2024
c2f0a8d
remove past_layer_norm (always true)
Cyrilvallez Sep 27, 2024
e251352
Update __init__.py
Cyrilvallez Sep 27, 2024
4fbcfce
Update config and license
Cyrilvallez Sep 27, 2024
a1692ab
start adding tests and doc
Cyrilvallez Sep 27, 2024
3c80274
Add doc + style
Cyrilvallez Sep 27, 2024
73ecb14
Update test_modeling_glm.py
Cyrilvallez Sep 27, 2024
d15d0e5
Add dummies
Cyrilvallez Sep 27, 2024
eacc0ad
Update back init (because __all__ is not generated from modular)
Cyrilvallez Sep 30, 2024
499f7a5
Update convert_glm_weights_to_hf.py
Cyrilvallez Sep 30, 2024
65085d2
Update convert_glm_weights_to_hf.py
Cyrilvallez Sep 30, 2024
d273a11
apply corrected modular
Cyrilvallez Sep 30, 2024
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -410,6 +410,8 @@
title: Gemma
- local: model_doc/gemma2
title: Gemma2
- local: model_doc/glm
title: GLM
- local: model_doc/openai-gpt
title: GPT
- local: model_doc/gpt_neo
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -136,6 +136,7 @@ Flax), PyTorch, and/or TensorFlow.
| [ErnieM](model_doc/ernie_m) | ✅ | ❌ | ❌ |
| [ESM](model_doc/esm) | ✅ | ✅ | ❌ |
| [FairSeq Machine-Translation](model_doc/fsmt) | ✅ | ❌ | ❌ |
| [GLM](model_doc/glm) | ✅ | ❌ | ❌ |
| [Falcon](model_doc/falcon) | ✅ | ❌ | ❌ |
| [FalconMamba](model_doc/falcon_mamba) | ✅ | ❌ | ❌ |
| [FastSpeech2Conformer](model_doc/fastspeech2_conformer) | ✅ | ❌ | ❌ |
99 changes: 99 additions & 0 deletions docs/source/en/model_doc/glm.md
@@ -0,0 +1,99 @@
<!--Copyright 2024 The GLM & ZhipuAI team and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# GLM

## Overview

The GLM Model was proposed
in [ChatGLM: A Family of Large Language Models from GLM-130B to GLM-4 All Tools](https://arxiv.org/html/2406.12793v1)
by GLM Team, THUDM & ZhipuAI.

The abstract from the paper is the following:

*We introduce ChatGLM, an evolving family of large language models that we have been developing over time. This report
primarily focuses on the GLM-4 language series, which includes GLM-4, GLM-4-Air, and GLM-4-9B. They represent our most
capable models that are trained with all the insights and lessons gained from the preceding three generations of
ChatGLM. To date, the GLM-4 models are pre-trained on ten trillions of tokens mostly in Chinese and English, along with
a small set of corpus from 24 languages, and aligned primarily for Chinese and English usage. The high-quality alignment
is achieved via a multi-stage post-training process, which involves supervised fine-tuning and learning from human
feedback. Evaluations show that GLM-4 1) closely rivals or outperforms GPT-4 in terms of general metrics such as MMLU,
GSM8K, MATH, BBH, GPQA, and HumanEval, 2) gets close to GPT-4-Turbo in instruction following as measured by IFEval, 3)
matches GPT-4 Turbo (128K) and Claude 3 for long context tasks, and 4) outperforms GPT-4 in Chinese alignments as
measured by AlignBench. The GLM-4 All Tools model is further aligned to understand user intent and autonomously decide
when and which tool(s) to use—including web browser, Python interpreter, text-to-image model, and user-defined
functions—to effectively complete complex tasks. In practical applications, it matches and even surpasses GPT-4 All
Tools in tasks like accessing online information via web browsing and solving math problems using Python interpreter.
Over the course, we have open-sourced a series of models, including ChatGLM-6B (three generations), GLM-4-9B (128K, 1M),
GLM-4V-9B, WebGLM, and CodeGeeX, attracting over 10 million downloads on Hugging face in the year 2023 alone.*

Tips:

- This model was contributed by [THUDM](https://huggingface.co/THUDM). The most recent code can be
  found [here](https://github.com/thudm/GLM-4).


## Usage tips

`GLM-4` can be found on the [Hugging Face Hub](https://huggingface.co/collections/THUDM/glm-4-665fcf188c414b03c2f7e3b7).

In the following, we demonstrate how to use `glm-4-9b-chat` for inference. Note that we use the ChatML format for dialogue; this demo shows how to leverage `apply_chat_template` for that purpose.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = AutoModelForCausalLM.from_pretrained("THUDM/glm-4-9b-chat", device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat")

>>> prompt = "Give me a short introduction to large language model."

>>> messages = [{"role": "user", "content": prompt}]

>>> text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

>>> model_inputs = tokenizer([text], return_tensors="pt").to(device)

>>> generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512, do_sample=True)

>>> generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]

>>> response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
```
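
As an alternative to the manual template-plus-generate flow above, the same checkpoint can be driven through the high-level `pipeline` API, which applies the chat template internally. This is a minimal sketch, not part of the original model card, and it assumes a transformers version recent enough to accept chat-style message lists in the pipeline:

```python
>>> from transformers import pipeline

>>> generator = pipeline("text-generation", model="THUDM/glm-4-9b-chat", device_map="auto")

>>> messages = [{"role": "user", "content": "Give me a short introduction to large language model."}]

>>> # The pipeline applies the chat template and appends the assistant reply to the conversation
>>> output = generator(messages, max_new_tokens=128)
>>> print(output[0]["generated_text"][-1]["content"])
```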

## GlmConfig

[[autodoc]] GlmConfig

## GlmModel

[[autodoc]] GlmModel
- forward

## GlmForCausalLM

[[autodoc]] GlmForCausalLM
- forward

## GlmForSequenceClassification

[[autodoc]] GlmForSequenceClassification
- forward

## GlmForTokenClassification

[[autodoc]] GlmForTokenClassification
- forward
2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -42,6 +42,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [GLM](https://huggingface.co/docs/transformers/model_doc/glm#transformers.GLMModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
@@ -214,6 +215,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert#transformers.CamembertModel)
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
* [GLM](https://huggingface.co/docs/transformers/model_doc/glm#transformers.GLMModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
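
Since GLM appears in both the FlashAttention-2 and the SDPA lists above, switching backends only requires the `attn_implementation` argument at load time. A minimal sketch, reusing the `THUDM/glm-4-9b-chat` checkpoint from the usage example (FlashAttention-2 additionally requires the `flash-attn` package and a supported GPU):

```python
import torch
from transformers import AutoModelForCausalLM

# FlashAttention-2 backend (needs the flash-attn package installed)
model_fa2 = AutoModelForCausalLM.from_pretrained(
    "THUDM/glm-4-9b-chat",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

# PyTorch SDPA backend (no extra dependency)
model_sdpa = AutoModelForCausalLM.from_pretrained(
    "THUDM/glm-4-9b-chat",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="auto",
)
```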
6 changes: 3 additions & 3 deletions docs/source/en/quantization/hqq.md
100644 → 100755
@@ -30,13 +30,13 @@ To quantize a model, you need to create an [`HqqConfig`]. There are two ways of
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

# Method 1: all linear layers will use the same quantization config
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) #axis=0 is used by default
quant_config = HqqConfig(nbits=8, group_size=64)
```

``` Python
# Method 2: each linear layer with the same tag will use a dedicated quantization config
q4_config = {'nbits':4, 'group_size':64, 'quant_zero':False, 'quant_scale':False}
q3_config = {'nbits':3, 'group_size':32, 'quant_zero':False, 'quant_scale':False}
q4_config = {'nbits':4, 'group_size':64}
q3_config = {'nbits':3, 'group_size':32}
quant_config = HqqConfig(dynamic_config={
'self_attn.q_proj':q4_config,
'self_attn.k_proj':q4_config,
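
For context on the simplified HQQ examples above, either config is then passed to `from_pretrained` through `quantization_config`. A minimal sketch, assuming the `hqq` package is installed; the checkpoint name is purely illustrative and not prescribed by this PR:

```python
import torch
from transformers import AutoModelForCausalLM, HqqConfig

quant_config = HqqConfig(nbits=4, group_size=64)

# HQQ quantizes the linear layers on the fly while the model is loaded onto the GPU
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # illustrative checkpoint
    torch_dtype=torch.float16,
    device_map="cuda",
    quantization_config=quant_config,
)
```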
18 changes: 18 additions & 0 deletions src/transformers/__init__.py
@@ -502,6 +502,7 @@
"Kosmos2Config",
"Kosmos2Processor",
],
"models.glm": ["GlmConfig"],
"models.layoutlm": [
"LayoutLMConfig",
"LayoutLMTokenizer",
@@ -2549,6 +2550,15 @@
"LlamaPreTrainedModel",
]
)
_import_structure["models.glm"].extend(
[
"GlmForCausalLM",
"GlmForSequenceClassification",
"GlmForTokenClassification",
"GlmModel",
"GlmPreTrainedModel",
]
)
_import_structure["models.llava"].extend(
[
"LlavaForConditionalGeneration",
@@ -5266,6 +5276,7 @@
GitProcessor,
GitVisionConfig,
)
from .models.glm import GlmConfig
from .models.glpn import GLPNConfig
from .models.gpt2 import (
GPT2Config,
@@ -6977,6 +6988,13 @@
GitPreTrainedModel,
GitVisionModel,
)
from .models.glm import (
GlmForCausalLM,
GlmForSequenceClassification,
GlmForTokenClassification,
GlmModel,
GlmPreTrainedModel,
)
from .models.glpn import (
GLPNForDepthEstimation,
GLPNModel,
12 changes: 10 additions & 2 deletions src/transformers/integrations/hqq.py
@@ -66,6 +66,10 @@ def _prepare_for_hqq_linear(model, patch_params, has_been_replaced, current_key_

has_been_replaced = True

# Add these fake parameters to avoid loading fail
for att in ["W_q", "meta"]:
setattr(module, att, None)

if len(list(module.children())) > 0:
_, has_been_replaced = _prepare_for_hqq_linear(
module,
@@ -97,7 +101,7 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve

# Convert quantization_config to layer-wise config
skip_modules = quantization_config.skip_modules
quant_config = quantization_config.to_dict()
quant_config = quantization_config.quant_config
linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert))

if any(key in linear_tags for key in quant_config.keys()):
@@ -113,7 +117,11 @@
)

# We store quantization config as linear_tag -> hqq quant config
model.config.quantization_config = patch_params
model.config.quantization_config = {
"quant_config": quant_config,
"quant_method": quantization_config.quant_method,
"skip_modules": skip_modules,
}

if not has_been_replaced:
logger.warning("No linear modules were found in your model for quantization.")
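
The net effect of this hunk is that the serialized `model.config.quantization_config` now records the per-layer HQQ settings rather than the patch parameters, so the quantizer can be rebuilt when the checkpoint is reloaded. A rough sketch of the stored structure, with illustrative keys and values only:

```python
# Approximate shape of model.config.quantization_config after this change
# (entries below are illustrative, not taken from a real checkpoint)
quantization_config = {
    "quant_config": {
        "self_attn.q_proj": {"nbits": 4, "group_size": 64},
        "self_attn.k_proj": {"nbits": 4, "group_size": 64},
        "mlp.down_proj": {"nbits": 3, "group_size": 32},
    },
    "quant_method": "hqq",
    "skip_modules": ["lm_head"],
}
```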
11 changes: 10 additions & 1 deletion src/transformers/modeling_utils.py
@@ -934,12 +934,17 @@ def _load_state_dict_into_meta_model(
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29

old_param = model
splits = param_name.split(".")
for split in splits:
old_param = getattr(old_param, split)
# Not all the attributes of a module are Parameters/Tensor
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None
if old_param is None:
break

if old_param is not None:
if dtype is None:
param = param.to(old_param.dtype)
@@ -3819,6 +3824,7 @@ def from_pretrained(
from_pt = not (from_tf | from_flax)

# load pt weights early so that we know which dtype to init the model under

if from_pt:
if not is_sharded and state_dict is None:
# Time to load the checkpoint
@@ -4176,6 +4182,9 @@ def _load_pretrained_model(
expected_keys = list(model_state_dict.keys())
prefix = model.base_model_prefix

if hf_quantizer is not None:
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)

def _fix_key(key):
if "beta" in key:
return key.replace("beta", "bias")
@@ -4290,7 +4299,7 @@ def _fix_key(key):
value = torch.empty(*param.size(), dtype=target_dtype)
if (
not is_quantized
or getattr(hf_quantizer, "requires_parameters_quantization", False)
or (getattr(hf_quantizer, "requires_parameters_quantization", False))
or not hf_quantizer.check_quantized_param(
model, param_value=value, param_name=key, state_dict={}
)
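
The guard added to `_load_state_dict_into_meta_model` boils down to resolving a dotted parameter name while tolerating attributes that are not tensors (such as the fake `W_q`/`meta` placeholders registered for HQQ). A simplified, self-contained sketch of that pattern, not the exact library code:

```python
import torch


def resolve_param(model: torch.nn.Module, param_name: str):
    """Walk a dotted name like 'model.layers.0.mlp.gate_proj.weight' and return
    the parameter, or None if the path is missing or ends on a non-tensor attribute."""
    obj = model
    for attr in param_name.split("."):
        obj = getattr(obj, attr, None)
        if obj is None:
            return None
    # Plain Python attributes (e.g. placeholders set by a quantizer) must not be
    # used to infer the target dtype, so treat them the same as a missing param.
    if not isinstance(obj, (torch.nn.Parameter, torch.Tensor)):
        return None
    return obj
```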
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -97,6 +97,7 @@
gemma,
gemma2,
git,
glm,
glpn,
gpt2,
gpt_bigcode,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -114,6 +114,7 @@
("gemma", "GemmaConfig"),
("gemma2", "Gemma2Config"),
("git", "GitConfig"),
("glm", "GlmConfig"),
("glpn", "GLPNConfig"),
("gpt-sw3", "GPT2Config"),
("gpt2", "GPT2Config"),
@@ -413,6 +414,7 @@
("gemma", "Gemma"),
("gemma2", "Gemma2"),
("git", "GIT"),
("glm", "GLM"),
("glpn", "GLPN"),
("gpt-sw3", "GPT-Sw3"),
("gpt2", "OpenAI GPT-2"),
4 changes: 4 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -111,6 +111,7 @@
("gemma", "GemmaModel"),
("gemma2", "Gemma2Model"),
("git", "GitModel"),
("glm", "GlmModel"),
("glpn", "GLPNModel"),
("gpt-sw3", "GPT2Model"),
("gpt2", "GPT2Model"),
@@ -483,6 +484,7 @@
("gemma", "GemmaForCausalLM"),
("gemma2", "Gemma2ForCausalLM"),
("git", "GitForCausalLM"),
("glm", "GlmForCausalLM"),
("gpt-sw3", "GPT2LMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("gpt_bigcode", "GPTBigCodeForCausalLM"),
@@ -909,6 +911,7 @@
("funnel", "FunnelForSequenceClassification"),
("gemma", "GemmaForSequenceClassification"),
("gemma2", "Gemma2ForSequenceClassification"),
("glm", "GlmForSequenceClassification"),
("gpt-sw3", "GPT2ForSequenceClassification"),
("gpt2", "GPT2ForSequenceClassification"),
("gpt_bigcode", "GPTBigCodeForSequenceClassification"),
@@ -1093,6 +1096,7 @@
("funnel", "FunnelForTokenClassification"),
("gemma", "GemmaForTokenClassification"),
("gemma2", "Gemma2ForTokenClassification"),
("glm", "GlmForTokenClassification"),
("gpt-sw3", "GPT2ForTokenClassification"),
("gpt2", "GPT2ForTokenClassification"),
("gpt_bigcode", "GPTBigCodeForTokenClassification"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -204,6 +204,7 @@
),
),
("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("gpt_bigcode", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
60 changes: 60 additions & 0 deletions src/transformers/models/glm/__init__.py
@@ -0,0 +1,60 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)


_import_structure = {
"configuration_glm": ["GlmConfig"],
}

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_glm"] = [
"GlmForCausalLM",
"GlmModel",
"GlmPreTrainedModel",
"GlmForSequenceClassification",
"GlmForTokenClassification",
]

if TYPE_CHECKING:
from .configuration_glm import GlmConfig

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_glm import (
GlmForCausalLM,
GlmForSequenceClassification,
GlmForTokenClassification,
GlmModel,
GlmPreTrainedModel,
)

else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
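
With the lazy module registered, the new classes become importable from the top-level `transformers` namespace without eagerly importing the modeling code. A minimal sketch using a deliberately tiny, illustrative configuration rather than the released GLM-4 hyperparameters:

```python
from transformers import GlmConfig, GlmForCausalLM

# Tiny illustrative config so the randomly initialized model is cheap to build
config = GlmConfig(
    vocab_size=1024,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=16,
)
model = GlmForCausalLM(config)
print(f"{model.num_parameters():,} parameters")
```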