Skip to content

Commit

Permalink
customized function to get response for sagemaker
Browse files Browse the repository at this point in the history
  • Loading branch information
HuXiangkun committed Sep 5, 2024
1 parent c7ceebc commit c575b53
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "refchecker"
version = "0.2.8"
version = "0.2.9"
description = "RefChecker provides automatic checking pipeline for detecting fine-grained hallucinations generated by Large Language Models."
authors = [
"Xiangkun Hu <[email protected]>",
Expand Down
5 changes: 4 additions & 1 deletion refchecker/checker/checker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def check(
joint_check_num: int = 5,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
**kwargs
):
"""
Expand Down Expand Up @@ -98,6 +99,7 @@ def check(
joint_check_num=joint_check_num,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
**kwargs
)
if merge_psg:
Expand Down Expand Up @@ -138,7 +140,8 @@ def check(
questions=[inp[3] for inp in input_flattened],
is_joint=False,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func
)

ret = [[x] + y for x, y in zip(ret, input_ids)]
Expand Down
3 changes: 3 additions & 0 deletions refchecker/checker/llm_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def _check(
joint_check_num: int = 5,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
**kwargs
):
"""
Expand Down Expand Up @@ -128,6 +129,7 @@ def _check(
api_base=self.api_base,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
**kwargs
)

Expand Down Expand Up @@ -208,6 +210,7 @@ def _check(
api_base=self.api_base,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
**kwargs
)

Expand Down
5 changes: 5 additions & 0 deletions refchecker/extractor/extractor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def extract(
max_new_tokens=500,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
**kwargs
):
if self.claim_format == 'triplet':
Expand All @@ -27,6 +28,7 @@ def extract(
max_new_tokens=max_new_tokens,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
**kwargs
)
elif self.claim_format == 'subsentence':
Expand All @@ -36,6 +38,7 @@ def extract(
max_new_tokens=max_new_tokens,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
**kwargs
)
return result
Expand All @@ -47,6 +50,7 @@ def extract_claim_triplets(
max_new_tokens=500,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
**kwargs
):
raise NotImplementedError
Expand All @@ -58,6 +62,7 @@ def extract_subsentence_claims(
max_new_tokens=500,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
**kwargs
):
raise NotImplementedError
Expand Down
4 changes: 4 additions & 0 deletions refchecker/extractor/llm_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def extract_subsentence_claims(
max_new_tokens=500,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
**kwargs
):
"""Extract subsentence claims from the response text.
Expand Down Expand Up @@ -76,6 +77,7 @@ def extract_subsentence_claims(
api_base=self.api_base,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
**kwargs
)

Expand Down Expand Up @@ -103,6 +105,7 @@ def extract_claim_triplets(
max_new_tokens=500,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
**kwargs
):
"""Extract KG triplets from the response text.
Expand Down Expand Up @@ -150,6 +153,7 @@ def extract_claim_triplets(
api_base=self.api_base,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
**kwargs
)

Expand Down
9 changes: 6 additions & 3 deletions refchecker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def get_model_batch_response(
api_base=None,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
**kwargs
):
"""
Expand Down Expand Up @@ -120,9 +121,11 @@ def get_model_batch_response(
),
ContentType="application/json",
)

r = json.loads(r['Body'].read().decode('utf8'))
response = r['outputs'][0]
if sagemaker_get_response_func is not None:
response = sagemaker_get_response_func(r)
else:
r = json.loads(r['Body'].read().decode('utf8'))
response = r['outputs'][0]
response_list.append(response)
return response_list
else:
Expand Down

0 comments on commit c575b53

Please sign in to comment.