Commit

Refactor dead code - Removing all flash_xxx.py files. (#2166)
* Refactor dead code.

* First working step.

* Remove a lot of duplicated code.

* More dead code.

* More cleanup.

* Fix Santacoder test.

* Fixing the simple tests.

* Fixing sharding.

* Fixes for VLM.

* Fixing santacoder (num_kv_heads hardcoded).

* Removing more dead code.

* Fixing `config.n_head`.

* Stopping earlier because of `<end_of_utterance>` in idefics2.

* Addresses comments.

* Removing the dead code.

* Fuse back mistral into FlashCausalLM.

* Finish removal.

* Fixing docs + causal_lm `batch_class`.

* Fixing docs + causal_lm.

* Add default to Gemma Causality.

* Default value for gemma/gemma2.

* Wrong default.
Narsil authored and ErikKaum committed Jul 26, 2024
1 parent f598d24 commit f6982b8
Showing 43 changed files with 689 additions and 2,451 deletions.
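
The unifying idea, as a hedged reading of the diff rather than a statement of TGI internals: instead of shipping one flash_xxx.py subclass per architecture, a single generic FlashCausalLM runner is parameterized with the architecture's modeling class, as the BLOOMSharded and CausalLM.fallback changes below illustrate. A minimal sketch of that pattern, where every name except FlashCausalLM is assumed for illustration:

from typing import Optional, Type


class FlashLlamaForCausalLM:
    """Stand-in for an architecture-specific custom_modeling class."""


class FlashCausalLM:
    """Generic runner owning weight loading, batching, and generation."""

    def __init__(
        self, model_id: str, model_class: Type, revision: Optional[str] = None
    ):
        self.model_id = model_id
        self.revision = revision
        # Only the architecture-specific forward pass comes from model_class;
        # real code would instantiate it from downloaded weights here.
        self.model_class = model_class


# One dispatch entry per architecture replaces one flash_xxx.py file each.
MODEL_CLASSES = {"llama": FlashLlamaForCausalLM}


def get_model(model_type: str, model_id: str) -> FlashCausalLM:
    return FlashCausalLM(model_id, model_class=MODEL_CLASSES[model_type])
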
2 changes: 1 addition & 1 deletion docs/openapi.json
@@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "2.1.1-dev0"
"version": "2.1.2-dev0"
},
"paths": {
"/": {
1 change: 1 addition & 0 deletions docs/source/supported_models.md
@@ -10,6 +10,7 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [Gemma](https://huggingface.co/google/gemma-7b)
+ - [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
- [Gemma2](https://huggingface.co/google/gemma2-9b)
- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
- [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
@@ -1,130 +1,124 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 20,
"finish_reason": "eos_token",
"generated_tokens": 19,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 415,
"logprob": -0.039886475,
"logprob": -0.03665161,
"special": false,
"text": " The"
},
{
"id": 12072,
"logprob": -0.1430664,
"logprob": -0.13549805,
"special": false,
"text": " cow"
},
{
"id": 349,
"logprob": -0.056488037,
"logprob": -0.05819702,
"special": false,
"text": " is"
},
{
"id": 6328,
"logprob": -0.6855469,
"logprob": -0.6826172,
"special": false,
"text": " standing"
},
{
"id": 356,
"logprob": -0.1685791,
"logprob": -0.1607666,
"special": false,
"text": " on"
},
{
"id": 272,
"logprob": -0.50097656,
"logprob": -0.5073242,
"special": false,
"text": " the"
},
{
"id": 10305,
"logprob": -0.017303467,
"logprob": -0.016418457,
"special": false,
"text": " beach"
},
{
"id": 304,
"logprob": -1.3564453,
"logprob": -1.3916016,
"special": false,
"text": " and"
},
{
"id": 272,
"logprob": -0.017868042,
"logprob": -0.020217896,
"special": false,
"text": " the"
},
{
"id": 13088,
"logprob": -0.0027103424,
"logprob": -0.0028133392,
"special": false,
"text": " chicken"
},
{
"id": 349,
"logprob": -0.003156662,
"logprob": -0.003145218,
"special": false,
"text": " is"
},
{
"id": 6398,
"logprob": -0.37304688,
"logprob": -0.37060547,
"special": false,
"text": " sitting"
},
{
"id": 356,
"logprob": -0.034576416,
"logprob": -0.034851074,
"special": false,
"text": " on"
},
{
"id": 264,
"logprob": -0.29418945,
"logprob": -0.2878418,
"special": false,
"text": " a"
},
{
"id": 17972,
"logprob": -0.042877197,
"logprob": -0.046051025,
"special": false,
"text": " pile"
},
{
"id": 302,
"logprob": -0.00028443336,
"logprob": -0.00028848648,
"special": false,
"text": " of"
},
{
"id": 2445,
"logprob": -0.023223877,
"logprob": -0.025772095,
"special": false,
"text": " money"
},
{
"id": 28723,
"logprob": -0.018157959,
"logprob": -0.018127441,
"special": false,
"text": "."
},
{
"id": 32002,
"logprob": -0.00018393993,
"logprob": -0.00019824505,
"special": true,
"text": "<end_of_utterance>"
},
- {
- "id": 2,
- "logprob": -1.1920929e-07,
- "special": true,
- "text": "</s>"
- }
],
"top_tokens": null
2 changes: 1 addition & 1 deletion integration-tests/models/test_idefics2.py
@@ -57,7 +57,7 @@ async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot)
response.generated_text
== " The cow is standing on the beach and the chicken is sitting on a pile of money."
), f"{repr(response.generated_text)}"
- assert response.details.generated_tokens == 20
+ assert response.details.generated_tokens == 19
assert response == response_snapshot


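The idefics2 change above is behavioral: `<end_of_utterance>` is now honored as a stop token, so generation finishes after 19 tokens with finish_reason "eos_token" instead of running to the 20-token cap ("length") and emitting a trailing </s>, matching the snapshot diff. A toy decoding loop showing the mechanism, with token ids taken from the snapshot; nothing here is TGI's actual implementation:

STOP_TOKEN_IDS = {2, 32002}  # </s> and idefics2's <end_of_utterance>


def generate(next_token_fn, max_new_tokens: int = 20):
    """Decode step by step: stop early on a stop token, else run to the cap."""
    tokens, finish_reason = [], "length"
    for _ in range(max_new_tokens):
        token_id = next_token_fn()
        tokens.append(token_id)
        if token_id in STOP_TOKEN_IDS:
            finish_reason = "eos_token"
            break
    return tokens, finish_reason
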
8 changes: 7 additions & 1 deletion server/tests/models/test_bloom.py
@@ -8,6 +8,9 @@
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.utils import weight_hub_files, download_weights
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
+ from text_generation_server.models.custom_modeling.bloom_modeling import (
+     BloomForCausalLM,
+ )


@pytest.fixture(scope="session")
@@ -16,7 +19,10 @@ def default_bloom():
revision = "main"
filenames = weight_hub_files(model_id, revision, ".safetensors")
download_weights(filenames, model_id, revision)
- return BLOOMSharded(model_id)
+ return BLOOMSharded(
+     model_id,
+     model_class=BloomForCausalLM,
+ )


@pytest.fixture(scope="session")
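Read together, the two hunks above leave the fixture looking roughly like this; the model_id value is elided from the diff, so the one below is assumed for illustration:

import pytest

from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.custom_modeling.bloom_modeling import (
    BloomForCausalLM,
)
from text_generation_server.utils import weight_hub_files, download_weights


@pytest.fixture(scope="session")
def default_bloom():
    model_id = "bigscience/bloom-560m"  # assumed; not shown in the hunk
    revision = "main"
    filenames = weight_hub_files(model_id, revision, ".safetensors")
    download_weights(filenames, model_id, revision)
    # The runner is now handed its modeling class explicitly instead of
    # hardcoding it inside a dedicated subclass.
    return BLOOMSharded(
        model_id,
        model_class=BloomForCausalLM,
    )
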
2 changes: 1 addition & 1 deletion server/tests/models/test_causal_lm.py
@@ -10,7 +10,7 @@

@pytest.fixture(scope="session")
def default_causal_lm():
return CausalLM("gpt2")
return CausalLM.fallback("gpt2")


@pytest.fixture(scope="session")
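The `fallback` constructor, also adopted for santacoder and Seq2SeqLM below, replaces the dedicated test-only model classes. A plausible sketch of the idea under assumed signatures, not TGI's actual code: load a stock transformers model when no optimized flash implementation applies.

from typing import Optional

from transformers import AutoModelForCausalLM, AutoTokenizer


class CausalLM:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    @classmethod
    def fallback(cls, model_id: str, revision: Optional[str] = None):
        # No flash/paged-attention path: load the stock implementation instead.
        tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
        model = AutoModelForCausalLM.from_pretrained(model_id, revision=revision)
        return cls(model, tokenizer)
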
5 changes: 2 additions & 3 deletions server/tests/models/test_santacoder.py
@@ -1,13 +1,12 @@
import pytest

from text_generation_server.pb import generate_pb2
- from text_generation_server.models.causal_lm import CausalLMBatch
- from text_generation_server.models.santacoder import SantaCoder
+ from text_generation_server.models.causal_lm import CausalLMBatch, CausalLM


@pytest.fixture(scope="session")
def default_santacoder():
- return SantaCoder("bigcode/santacoder")
+ return CausalLM.fallback(model_id="bigcode/santacoder")


@pytest.fixture
2 changes: 1 addition & 1 deletion server/tests/models/test_seq2seq_lm.py
@@ -20,7 +20,7 @@ def mt0_small_tokenizer():

@pytest.fixture(scope="session")
def default_seq2seq_lm():
- return Seq2SeqLM("bigscience/mt0-small")
+ return Seq2SeqLM.fallback("bigscience/mt0-small")


@pytest.fixture