Skip to content

Commit

Permalink
chore: make sure we have no pbs in hybrid model for linear layers in …
Browse files Browse the repository at this point in the history
…test (#764)
  • Loading branch information
jfrery authored Jun 26, 2024
1 parent 6d8a205 commit 7a2eeea
Showing 1 changed file with 28 additions and 7 deletions.
35 changes: 28 additions & 7 deletions tests/torch/test_hybrid_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ def test_tuple_serialization(tup):


def run_hybrid_llm_test(
model: torch.nn.Module, inputs: torch.Tensor, module_names: Union[str, List], expected_accuracy
model: torch.nn.Module,
inputs: torch.Tensor,
module_names: Union[str, List],
expected_accuracy,
has_pbs: bool,
):
"""Run the test for any model with its private module names."""

Expand All @@ -49,6 +53,21 @@ def run_hybrid_llm_test(
inputs, p_error=0.01, n_bits=8, rounding_threshold_bits=8, configuration=configuration
)

if has_pbs:
# Check for non-zero programmable bootstrapping
for module in hybrid_model.private_q_modules.values():
assert module.fhe_circuit.statistics["programmable_bootstrap_count"] > 0, (
"Programmable bootstrap count should be greater than 0, "
f"but found {module.fhe_circuit.statistics['programmable_bootstrap_count']}"
)
else:
# Check for zero programmable bootstrapping
for module in hybrid_model.private_q_modules.values():
assert module.fhe_circuit.statistics["programmable_bootstrap_count"] == 0, (
"Programmable bootstrap count should be 0, "
f"but found {module.fhe_circuit.statistics['programmable_bootstrap_count']}"
)

# Check we can run the simulate locally
logits_simulate = hybrid_model(inputs, fhe="simulate").logits
logits_disable = hybrid_model(inputs, fhe="disable").logits
Expand Down Expand Up @@ -106,14 +125,14 @@ def run_hybrid_llm_test(
# 'from_pretrained' method
@pytest.mark.filterwarnings("ignore::FutureWarning")
@pytest.mark.parametrize(
"list_or_str_private_modules_names, expected_accuracy",
"list_or_str_private_modules_names, expected_accuracy, has_pbs",
[
("transformer.h.0.mlp", 0.934),
(["transformer.h.0.mlp", "transformer.h.1.mlp"], 0.42),
("transformer.h.0.mlp.c_fc", 0.986),
("transformer.h.0.mlp", 0.934, True),
(["transformer.h.0.mlp", "transformer.h.1.mlp"], 0.42, True),
("transformer.h.0.mlp.c_fc", 0.986, False),
],
)
def test_gpt2_hybrid_mlp(list_or_str_private_modules_names, expected_accuracy):
def test_gpt2_hybrid_mlp(list_or_str_private_modules_names, expected_accuracy, has_pbs):
"""Test GPT2 hybrid."""

# Get GPT2 from Hugging Face
Expand All @@ -125,7 +144,9 @@ def test_gpt2_hybrid_mlp(list_or_str_private_modules_names, expected_accuracy):

# Run the test with using a single module in FHE
assert isinstance(model, torch.nn.Module)
run_hybrid_llm_test(model, input_ids, list_or_str_private_modules_names, expected_accuracy)
run_hybrid_llm_test(
model, input_ids, list_or_str_private_modules_names, expected_accuracy, has_pbs
)


def test_hybrid_brevitas_qat_model():
Expand Down

0 comments on commit 7a2eeea

Please sign in to comment.