CRE-LLM: A Domain-Specific Chinese Relation Extraction Framework with Fine-tuned Large Language Model
- Python 3.8+
- PyTorch 1.13.1+
- 🤗Transformers 4.37.2+
- Datasets 2.14.3+
- Accelerate 0.27.2+
- PEFT 0.10.0+
- TRL 0.8.1+
- Gradio 4.0.0+
And powerful GPUs!
git clone https://github.com/SkyuForever/CRE-LLM.git
conda create -n CRE-LLM python=3.10
conda activate CRE-LLM
cd CRE-LLM
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
pip install -r requirements.txt
Experiments are conducted on two CRE benchmarks: FinRE and SanWen.
The FinRE dataset has been downloaded under data/FinRE.
CRE-LLM/
└── data/
    └── FinRE/
        ├── train.txt
        ├── valid.txt
        ├── test.txt
        └── relation2id.txt
The SanWen dataset has been downloaded under data/SanWen.
CRE-LLM/
└── data/
    └── SanWen/
        ├── train.txt
        ├── valid.txt
        ├── test.txt
        └── relation2id.txt
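
Before running the preprocessing step, it can help to confirm that the raw files listed above are actually present. The following is a minimal sketch, not part of the repository: the helper name `missing_raw_files` is hypothetical, and the script assumes it is run from the CRE-LLM/ root.

```python
from pathlib import Path

# File and directory names are taken from the listings above.
EXPECTED_FILES = ("train.txt", "valid.txt", "test.txt", "relation2id.txt")
DATASETS = ("FinRE", "SanWen")


def missing_raw_files(data_root, datasets=DATASETS, expected=EXPECTED_FILES):
    """Return the expected raw files that are absent under data_root.

    Hypothetical helper for illustration; the repository does not ship it.
    """
    root = Path(data_root)
    return [
        root / name / fname
        for name in datasets
        for fname in expected
        if not (root / name / fname).is_file()
    ]


if __name__ == "__main__":
    # Assumes the current working directory is the CRE-LLM/ repository root.
    missing = missing_raw_files("data")
    if missing:
        print("Missing files:")
        for path in missing:
            print(f"  {path}")
    else:
        print("All raw dataset files are in place.")
```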
Prepare data for training and evaluation
- FinRE: Run python data_process.py. The augmented dataset files are saved as data/FinRE/train.json and data/FinRE/test.json.
- SanWen: Run python data_process.py. The augmented dataset files are saved as data/SanWen/train.json and data/SanWen/test.json.
Note
Please update data/dataset_info.json to use your custom dataset. For the format of this file, please refer to data/README.md.
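
As a rough illustration of what updating data/dataset_info.json might involve, the sketch below merges a few entries into the file without overwriting existing ones. The dataset names come from the --dataset values used in the commands in this README; the `file_name` paths and the helper `register_datasets` are assumptions made for this example, and data/README.md remains the authoritative description of the schema.

```python
import json
from pathlib import Path

# Assumed mapping from dataset name (the value passed to --dataset) to a
# file path relative to data/. These paths are illustrative guesses; adjust
# them to match your own files and the schema in data/README.md.
ASSUMED_ENTRIES = {
    "FinRE_train": {"file_name": "FinRE/train.json"},
    "FinRE_test": {"file_name": "FinRE/test.json"},
    "SanWen_train": {"file_name": "SanWen/train.json"},
    "SanWen_test": {"file_name": "SanWen/test.json"},
}


def register_datasets(info_path, entries):
    """Merge entries into a dataset_info.json file, keeping existing ones.

    Hypothetical helper for illustration; not part of the repository.
    """
    path = Path(info_path)
    info = json.loads(path.read_text(encoding="utf-8")) if path.exists() else {}
    for name, spec in entries.items():
        info.setdefault(name, spec)  # never overwrite an entry already present
    path.write_text(
        json.dumps(info, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    return info
```

Calling `register_datasets("data/dataset_info.json", ASSUMED_ENTRIES)` from the repository root would add only the entries that are not yet defined.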
You can also obtain the processed CRE-LLM data directly from our project.
Remember to place it under data/.
CRE-LLM/
└── data/
    ├── FinRE/
    │   ├── train.json
    │   ├── valid.json
    │   └── test.json
    └── SanWen/
        ├── train.json
        ├── valid.json
        └── test.json
The following is an example of LLM fine-tuning and evaluation on FinRE and SanWen. Our framework supports the following LLMs:
Model | Model size | Default module | Template |
---|---|---|---|
LLaMA-2 | 7B/13B/70B | q_proj,v_proj | llama2 |
Baichuan2 | 7B/13B | W_pack | baichuan2 |
ChatGLM2 | 6B | query_key_value | chatglm2 |
Note
For the "base" models, the --template argument can be chosen from default, alpaca, vicuna, etc. For the "instruct/chat" models, make sure to use the corresponding template.
Remember to use the SAME template in training and inference.
For more information on fine-tuning, please refer to LLaMA-Efficient-Tuning.
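
One way to keep the template and LoRA target consistent between training and inference is to look both up from a single table. The sketch below copies the values from the model table above; the function `cli_args` and its `chat_model` flag are hypothetical names introduced for illustration, and the base/chat rule follows the note above.

```python
# Values copied from the model table above.
MODEL_SETTINGS = {
    "LLaMA-2": {"lora_target": "q_proj,v_proj", "template": "llama2"},
    "Baichuan2": {"lora_target": "W_pack", "template": "baichuan2"},
    "ChatGLM2": {"lora_target": "query_key_value", "template": "chatglm2"},
}


def cli_args(model, chat_model=True):
    """Return the --template and --lora_target arguments for a model.

    Hypothetical helper: chat/instruct models get their own template,
    base models fall back to the "default" template.
    """
    settings = MODEL_SETTINGS[model]
    template = settings["template"] if chat_model else "default"
    return ["--template", template, "--lora_target", settings["lora_target"]]


if __name__ == "__main__":
    print(" ".join(cli_args("Baichuan2")))
    print(" ".join(cli_args("LLaMA-2", chat_model=False)))
```

Reusing the same call for the training and prediction commands avoids accidentally mixing templates.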
(1) Supervised Fine-Tuning of LLMs for Triple Generation
- FinRE:
Train LLMs for Triple Generation:
CUDA_VISIBLE_DEVICES=0 nohup python src/train_bash.py --model_name_or_path path_to_model --stage sft --do_train --dataset FinRE_train --template default --finetuning_type lora --lora_target q_proj,v_proj --output_dir path_to_sft_checkpoint --overwrite_cache --per_device_train_batch_size 4 --gradient_accumulation_steps 4 --lr_scheduler_type cosine --logging_steps 10 --save_steps 1000 --learning_rate 5e-5 --num_train_epochs 5 --plot_loss --fp16 >> FinRE_train.txt 2>&1 &
Test LLMs for Triple Generation:
CUDA_VISIBLE_DEVICES=3 nohup python -u src/train_bash.py --model_name_or_path path_to_model --stage sft --dataset FinRE_test --do_predict --template default --finetuning_type lora --checkpoint_dir path_to_checkpoint --output_dir path_to_predict_result --per_device_eval_batch_size 8 --max_samples 100 --predict_with_generate >> FinRE_test.txt 2>&1 &
- SanWen:
Train LLMs for Triple Generation:
CUDA_VISIBLE_DEVICES=0 nohup python src/train_bash.py --model_name_or_path path_to_model --stage sft --do_train --dataset SanWen_train --template default --finetuning_type lora --lora_target q_proj,v_proj --output_dir path_to_sft_checkpoint --overwrite_cache --per_device_train_batch_size 4 --gradient_accumulation_steps 4 --lr_scheduler_type cosine --logging_steps 10 --save_steps 1000 --learning_rate 5e-4 --num_train_epochs 10 --plot_loss --fp16 >> SanWen_train.txt 2>&1 &
Test LLMs for Triple Generation:
CUDA_VISIBLE_DEVICES=3 nohup python -u src/train_bash.py --model_name_or_path path_to_model --stage sft --dataset SanWen_test --do_predict --template default --finetuning_type lora --checkpoint_dir path_to_checkpoint --output_dir path_to_predict_result --per_device_eval_batch_size 8 --max_samples 100 --predict_with_generate >> SanWen_test.txt 2>&1 &
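
The training commands above combine --per_device_train_batch_size 4 with --gradient_accumulation_steps 4 on a single visible GPU, so each optimizer update sees 16 examples. The sketch below works through that arithmetic; the dataset size of 10,000 examples is a made-up number used only to illustrate how --save_steps 1000 relates to the run length, and the step count is an approximation rather than the trainer's exact bookkeeping.

```python
import math


def effective_batch_size(per_device, grad_accum, num_devices=1):
    """Examples contributing to one optimizer update."""
    return per_device * grad_accum * num_devices


def approx_optimizer_steps(num_examples, batch, epochs):
    """Approximate number of optimizer updates over the whole run."""
    return math.ceil(num_examples / batch) * epochs


if __name__ == "__main__":
    batch = effective_batch_size(per_device=4, grad_accum=4, num_devices=1)
    print(batch)  # 16

    # Hypothetical dataset size, for illustration only.
    steps = approx_optimizer_steps(num_examples=10_000, batch=batch, epochs=5)
    print(steps)  # 3125
    print(steps // 1000)  # intermediate checkpoints with --save_steps 1000: 3
```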
(2) Evaluate CRE results
Run python eval.py to get the Accuracy, Recall and F1 score.
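
For readers who want to see how such scores are commonly computed for generated triples, here is a minimal sketch. It is not the logic of eval.py, whose input format and exact definitions are not described here. It assumes each sentence has a set of gold and predicted (head, relation, tail) tuples, counts a sentence as accurate only when both sets match exactly, and uses made-up triples as example data.

```python
def score_triples(gold, pred):
    """Score predicted triples against gold triples.

    gold, pred: lists of sets of (head, relation, tail) tuples, one set per
    sentence. This layout is an assumption made for illustration.
    """
    if len(gold) != len(pred):
        raise ValueError("gold and pred must have the same number of sentences")

    correct = sum(len(g & p) for g, p in zip(gold, pred))
    n_pred = sum(len(p) for p in pred)
    n_gold = sum(len(g) for g in gold)

    precision = correct / n_pred if n_pred else 0.0
    recall = correct / n_gold if n_gold else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if precision + recall
        else 0.0
    )
    # Sentence-level accuracy: the predicted set must equal the gold set.
    accuracy = sum(g == p for g, p in zip(gold, pred)) / len(gold) if gold else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


if __name__ == "__main__":
    # Made-up triples, for illustration only.
    gold = [{("A", "r1", "B")}, {("C", "r2", "D")}]
    pred = [{("A", "r1", "B")}, {("C", "r3", "D")}]
    print(score_triples(gold, pred))
```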