
[BUG] Deepspeed shows incorrect results for CodeLlama #4442

Closed
cupertank opened this issue Oct 3, 2023 · 2 comments
Labels: bug (Something isn't working), inference

Comments

cupertank commented Oct 3, 2023

Describe the bug
CodeLlama produces incorrect results when run with DeepSpeed. During my investigation, I found that DeepSpeed hardcodes rope_theta == 10000.0 in its rotary embedding kernel, while CodeLlama uses rope_theta == 1000000.0.
The offending line:

inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_idx;

See rope_theta in the CodeLlama config.

I think rope_theta should be a parameter of the rotary embedding rather than a hardcoded constant.
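
For reference, the value is exposed on the Hugging Face config object, so it is easy to confirm what CodeLlama expects (a quick check, assuming a transformers version recent enough to carry rope_theta, such as the 4.33.2 used below):

from transformers import AutoConfig

# CodeLlama ships rope_theta in its config.json; Llama-2 checkpoints use 10000.0.
config = AutoConfig.from_pretrained("codellama/CodeLlama-7b-hf")
print(config.rope_theta)  # 1000000.0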

To Reproduce
Steps to reproduce the behavior:

  1. Run this script:
from transformers import AutoModelForCausalLM, AutoTokenizer
from deepspeed import init_inference
import torch

test_input = """import abc
import gzip
import logging
import multiprocessing
import os
import sys
from multiprocessing import Pool
from typing import Iterable, Sequence

from tqdm.auto import tqdm

logger = logging."""

tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")
inputs = tokenizer(test_input, return_tensors="pt")["input_ids"].to("cuda")

# torch_dtype belongs on the model load (passing it to the tokenizer is a no-op).
model = AutoModelForCausalLM.from_pretrained("codellama/CodeLlama-7b-hf", torch_dtype=torch.float16).to("cuda")

transformers_result = model.generate(
    inputs,
    max_new_tokens=10
).to("cpu")

transformers_text = tokenizer.decode(transformers_result[0])

print("==============Transformers result==============")
print(transformers_text)
print("===============================================")


# Wrap the same model with DeepSpeed kernel-injected inference.
model = init_inference(
    model=model,
    dtype=torch.float16,
    replace_with_kernel_inject=True
)

deepspeed_outputs = model.generate(
    inputs,
    max_new_tokens=10
).to("cpu")

deepspeed_text = tokenizer.decode(deepspeed_outputs[0])

print("===============DeepSpeed result================")
print(deepspeed_text)
print("===============================================")

My output:

==============Transformers result==============
<s> import abc
import gzip
import logging
import multiprocessing
import os
import sys
from multiprocessing import Pool
from typing import Iterable, Sequence

from tqdm.auto import tqdm

logger = logging.getLogger(__name__)


class
===============================================
...
===============DeepSpeed result================
<s> import abc
import gzip
import logging
import multiprocessing
import os
import sys
from multiprocessing import Pool
from typing import Iterable, Sequence

from tqdm.auto import tqdm

logger = logging.getLogger(__name__))




===============================================

Expected behavior
I expected both engines to produce the same result.

ds_report output

[2023-10-03 13:20:28,866] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
async_io ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  please install triton==1.0.0 if you want to use sparse attention
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/ilya_vologin/giga_pizda/venv/lib/python3.8/site-packages/torch']
torch version .................... 1.13.1+cu117
deepspeed install path ........... ['/home/ilya_vologin/giga_pizda/venv/lib/python3.8/site-packages/deepspeed']
deepspeed info ................... 0.10.3, unknown, unknown
torch cuda version ............... 11.7
torch hip version ................ None
nvcc version ..................... 11.7
deepspeed wheel compiled w. ...... torch 1.13, cuda 11.7
shared memory (/dev/shm) size .... 83.53 GB

System info (please complete the following information):

  • OS: Debian 5.10.191-1 (2023-08-16) x86_64 GNU/Linux
  • GPU count and types: One machine with one A100 80GB
  • Hugging Face Transformers versions: transformers==4.33.2
  • Python version: 3.8.18

Additional context
If you change 10000.0 to 1000000.0 in this line:

inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_idx;

You will get correct results:

==============Transformers result==============
<s> import abc
import gzip
import logging
import multiprocessing
import os
import sys
from multiprocessing import Pool
from typing import Iterable, Sequence

from tqdm.auto import tqdm

logger = logging.getLogger(__name__)


class
===============================================
...
===============DeepSpeed result================
<s> import abc
import gzip
import logging
import multiprocessing
import os
import sys
from multiprocessing import Pool
from typing import Iterable, Sequence

from tqdm.auto import tqdm

logger = logging.getLogger(__name__)


class
===============================================
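
For intuition on why the hardcoded base corrupts generation, one can compare the inverse-frequency tables the two theta values produce. This is a standalone sketch of the standard RoPE formula, not DeepSpeed's kernel; dim=128 is assumed to match Llama-7B's head size:

import torch

def rope_inv_freq(theta: float, dim: int = 128) -> torch.Tensor:
    # Standard RoPE inverse frequencies: theta^(-2i/dim) for i in [0, dim/2).
    return 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

f_hardcoded = rope_inv_freq(10000.0)    # what the kernel bakes in
f_codellama = rope_inv_freq(1000000.0)  # what CodeLlama expects
print((f_hardcoded / f_codellama).max())  # ~93: lowest-frequency bands off by almost 100x

Each position's rotation angle is scaled by these factors, so with the wrong base the model effectively sees position encodings it was never trained on, which matches the degenerate continuation in the unpatched output above.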
@cupertank cupertank added bug Something isn't working inference labels Oct 3, 2023
@mrwyattii mrwyattii self-assigned this Oct 4, 2023
mrwyattii (Contributor) commented:
@cupertank thank you for reporting and finding the cause of this bug! I can work on getting a PR that will correct this (unless you planned to create a PR yourself).

cupertank (Author) commented:
@mrwyattii I made a pull request with the bugfix; could you please take a look at #4480?
