feat: add flash attn to inference and eval scripts #132

Merged
7 changes: 6 additions & 1 deletion scripts/run_evaluation.py
@@ -57,6 +57,11 @@ def parse_and_validate_args():
         action="store_true",
     )
     parser.add_argument("--purge_results", action=argparse.BooleanOptionalAction)
+    parser.add_argument(
+        "--use_flash_attn",
+        help="Whether to load the model using Flash Attention 2",
+        action="store_true",
+    )
     parsed_args = parser.parse_args()
 
     print(f"Multiclass / multioutput delimiter: {parsed_args.delimiter}")
@@ -441,7 +446,7 @@ def export_experiment_info(
 
 if __name__ == "__main__":
     args = parse_and_validate_args()
-    tuned_model = TunedCausalLM.load(args.model)
+    tuned_model = TunedCausalLM.load(args.model, use_flash_attn=args.use_flash_attn)
     eval_data = datasets.load_dataset(
         "json", data_files=args.data_path, split=args.split
     )
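Because the new option is registered with action="store_true", it defaults to False and only becomes True when --use_flash_attn is passed on the command line. A minimal standalone sketch of that behavior (the parser below is illustrative, not the script's actual parser object):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use_flash_attn",
    help="Whether to load the model using Flash Attention 2",
    action="store_true",
)

# Omitting the flag leaves it False; passing it flips it to True.
print(parser.parse_args([]).use_flash_attn)                     # False
print(parser.parse_args(["--use_flash_attn"]).use_flash_attn)   # True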
27 changes: 24 additions & 3 deletions scripts/run_inference.py
@@ -142,7 +142,10 @@ def __init__(self, model, tokenizer, device):
 
     @classmethod
     def load(
-        cls, checkpoint_path: str, base_model_name_or_path: str = None
+        cls,
+        checkpoint_path: str,
+        base_model_name_or_path: str = None,
+        use_flash_attn: bool = False,
     ) -> "TunedCausalLM":
         """Loads an instance of this model.
 
@@ -152,6 +155,8 @@ def load(
                 adapter_config.json.
             base_model_name_or_path: str [Default: None]
                 Override for the base model to be used.
+            use_flash_attn: bool [Default: False]
+                Whether to load the model using flash attention.
 
         By default, the paths for the base model and tokenizer are contained within the adapter
         config of the tuned model. Note that in this context, a path may refer to a model to be
@@ -173,14 +178,24 @@
         try:
             with AdapterConfigPatcher(checkpoint_path, overrides):
                 try:
-                    model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path)
+                    model = AutoPeftModelForCausalLM.from_pretrained(
+                        checkpoint_path,
+                        attn_implementation="flash_attention_2"
+                        if use_flash_attn
+                        else None,
+                        torch_dtype=torch.bfloat16 if use_flash_attn else None,
+                    )
                 except OSError as e:
                     print("Failed to initialize checkpoint model!")
                     raise e
         except FileNotFoundError:
             print("No adapter config found! Loading as a merged model...")
             # Unable to find the adapter config; fall back to loading as a merged model
-            model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
+            model = AutoModelForCausalLM.from_pretrained(
+                checkpoint_path,
+                attn_implementation="flash_attention_2" if use_flash_attn else None,
+                torch_dtype=torch.bfloat16 if use_flash_attn else None,
+            )
 
         device = "cuda" if torch.cuda.is_available() else None
         print(f"Inferred device: {device}")
@@ -246,6 +261,11 @@ def main():
         type=int,
         default=20,
     )
+    parser.add_argument(
+        "--use_flash_attn",
+        help="Whether to load the model using Flash Attention 2",
+        action="store_true",
+    )
     group = parser.add_mutually_exclusive_group(required=True)
     group.add_argument("--text", help="Text to run inference on")
     group.add_argument(
@@ -261,6 +281,7 @@
     loaded_model = TunedCausalLM.load(
         checkpoint_path=args.model,
         base_model_name_or_path=args.base_model_name_or_path,
+        use_flash_attn=args.use_flash_attn,
     )
 
     # Run inference on the text; if multiple were provided, process them all
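Taken together, both scripts now thread a single boolean from the CLI down to from_pretrained, where it selects attn_implementation="flash_attention_2" and pins torch_dtype to torch.bfloat16 (Flash Attention 2 requires a half-precision dtype); with the flag off, both kwargs resolve to None and the transformers defaults apply. A minimal sketch of calling the updated loader directly, assuming run_inference.py is importable from the repo root and using a placeholder checkpoint path:

from scripts.run_inference import TunedCausalLM  # import path assumed from this PR's file layout

tuned_model = TunedCausalLM.load(
    checkpoint_path="path/to/tuned/checkpoint",  # placeholder path
    base_model_name_or_path=None,                # default: resolved from the adapter config
    use_flash_attn=True,                         # needs flash-attn installed and a supported GPU
)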