forked from kookmin-sw/cap-template
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #9 from kookmin-sw/master
pr
- Loading branch information
Showing
31 changed files
with
417 additions
and
43 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# db_store.py | ||
|
||
from model import load_model_and_tokenizer, predict_entities | ||
from data_processing import find_career_status, find_phone_number, extract_and_combine_entities | ||
from datasets import load_dataset | ||
from config.db import connect_db, get_collection | ||
|
||
# MongoDB 데이터베이스 연결 | ||
db = connect_db() | ||
collection = db['ExtractedEntities'] # 원하는 컬렉션 이름을 지정 | ||
|
||
# KLUE NER 데이터셋 로드 | ||
dataset = load_dataset("klue", "ner") | ||
tag_list = dataset['train'].features['ner_tags'].feature.names | ||
tag2id = {tag: id for id, tag in enumerate(tag_list)} | ||
id2tag = {id: tag for tag, id in tag2id.items()} | ||
|
||
# 모델 및 토크나이저 로드 | ||
model, tokenizer = load_model_and_tokenizer() | ||
|
||
# 예시 텍스트 | ||
text = "25/ 김준호 /서초구 거주/경력:유/전화번호:010-0000-0000" | ||
|
||
# 엔티티 추출 및 결합 | ||
predicted_entities = predict_entities(text, model, tokenizer, id2tag) | ||
entities_combined = extract_and_combine_entities(predicted_entities) | ||
entities_combined["career"] = find_career_status(text) | ||
entities_combined["phonenumber"] = find_phone_number(text) | ||
entities_combined["sex"] = "남" | ||
entities_combined["RRN"] = "000000-0000000" | ||
|
||
# 데이터 MongoDB에 저장 | ||
insert_result = collection.insert_one(entities_combined) | ||
print(f"삽입된 문서 ID: {insert_result.inserted_id}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from config.db import connect_db | ||
import pprint | ||
|
||
db = connect_db() | ||
collection = db['ExtractedEntities'] | ||
db.list_collection_names() | ||
|
||
for document in collection.find(): | ||
print(document) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
from transformers import BertTokenizerFast, BertForTokenClassification, AdamW | ||
from transformers import Trainer, TrainingArguments | ||
from datasets import load_dataset, load_metric | ||
import torch | ||
from torch.utils.data import DataLoader | ||
from torch.nn.utils.rnn import pad_sequence | ||
import numpy as np | ||
from kobert_tokenizer import KoBERTTokenizer | ||
from kobert_transformers import get_kobert_model, get_tokenizer | ||
from transformers import BertForTokenClassification | ||
|
||
# KLUE NER 데이터셋 로드 | ||
dataset = load_dataset("klue", "ner") | ||
|
||
# 태그 리스트 확인 | ||
tag_list = dataset['train'].features['ner_tags'].feature.names | ||
print(tag_list) | ||
|
||
# tag2id 및 id2tag 사전 생성 | ||
tag2id = {tag: id for id, tag in enumerate(tag_list)} | ||
id2tag = {id: tag for tag, id in tag2id.items()} | ||
|
||
model_name = "mmoonssun/klue_ner_kobert" | ||
model = BertForTokenClassification.from_pretrained(model_name, num_labels=13) # num_labels는 데이터셋의 라벨 수에 맞춰 조정 | ||
tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1') | ||
|
||
import re | ||
|
||
def predict_entities(text, model, tokenizer, id2tag): | ||
# GPU 사용 설정 | ||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
model.to(device) | ||
|
||
# 평가 모드로 설정 | ||
model.eval() | ||
|
||
# 입력 문장 토크나이징 및 텐서로 변환 | ||
inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt", max_length=512) | ||
input_ids = inputs["input_ids"].to(device) | ||
attention_mask = inputs["attention_mask"].to(device) | ||
|
||
# 예측 수행 | ||
with torch.no_grad(): | ||
outputs = model(input_ids, attention_mask=attention_mask) | ||
logits = outputs.logits | ||
|
||
# 예측 결과에서 가장 높은 확률을 가진 태그 ID를 추출 | ||
predictions = torch.argmax(logits, dim=2) | ||
|
||
# ID를 태그로 변환 | ||
predicted_tags = [id2tag[id.item()] for id in predictions[0]] | ||
|
||
# 토큰화된 텍스트와 예측된 태그 결합 | ||
tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist()) | ||
token_tag_pairs = [(token, tag) for token, tag in zip(tokens, predicted_tags) if token not in ["[CLS]", "[SEP]", "[PAD]", "<pad>"]] | ||
|
||
# '▁' 문자를 공백으로 대체하여 보다 자연스러운 출력을 생성 | ||
token_tag_pairs = [(token.replace('▁', ' '), tag) for token, tag in token_tag_pairs] | ||
|
||
return token_tag_pairs | ||
|
||
def find_career_status(text): | ||
# '경력' 다음에 오는 '유', '무', '없', '있' 찾기 | ||
pattern = r'경력\s*:\s*(유|무|없|있)' | ||
|
||
# 문자열에서 패턴에 해당하는 부분 찾기 | ||
match = re.search(pattern, text) | ||
|
||
# 찾은 값을 변수에 저장하고 처리 | ||
if match: | ||
raw_career = match.group(1) # 첫 번째 그룹(유|무|없|있)을 추출 | ||
# '없'이나 '있'을 각각 '무', '유'로 변환 | ||
if raw_career == '없': | ||
career = '무' | ||
elif raw_career == '있': | ||
career = '유' | ||
else: | ||
career = raw_career | ||
|
||
return career | ||
else: | ||
return "경력 유무를 찾을 수 없습니다." | ||
|
||
def find_phone_number(text): | ||
# 정규 표현식으로 전화번호 패턴 찾기 | ||
# 패턴 설명: '010'으로 시작하며, '-'가 있을 수도 있고 없을 수도 있으며, 숫자가 연속으로 나타남 | ||
pattern = r'010-?\d{4}-?\d{4}' | ||
|
||
# 문자열에서 패턴에 해당하는 부분 찾기 | ||
match = re.search(pattern, text) | ||
|
||
# 찾은 전화번호를 변수에 저장하고 출력 | ||
if match: | ||
phone_number = match.group() | ||
return phone_number | ||
else: | ||
return "전화번호를 찾을 수 없습니다." | ||
|
||
def extract_and_combine_entities(predicted_entities): | ||
name = "" | ||
location = "" | ||
age = "" | ||
|
||
for token, tag in predicted_entities: | ||
if tag == 'B-PS': # 이름 추출 | ||
name += token.strip() | ||
elif tag == 'B-LC': # 위치 추출 | ||
if token == " ": | ||
location += token | ||
else: | ||
location += token.strip() | ||
elif tag == 'B-QT' or tag == 'B-DT': # 나이(수량) 추출 | ||
age = token.strip() | ||
|
||
# 결과 반환 | ||
return {"name": name, "location": location, "age": age} | ||
|
||
# 예시 문장 | ||
text = "25/ 김준호 /서초구 거주/경력:유/전화번호:010-0000-0000" | ||
|
||
# 모델을 사용하여 문장에서 개체 추출 | ||
predicted_entities = predict_entities(text, model, tokenizer, id2tag) | ||
career = find_career_status(text) | ||
phone_number = find_phone_number(text) | ||
|
||
# 함수를 사용하여 변수에 저장 | ||
entities_combined = extract_and_combine_entities(predicted_entities) | ||
|
||
# 결과 출력 | ||
entities_combined["career"] = career | ||
entities_combined["phone_number"] = phone_number | ||
|
||
print(entities_combined) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
{ | ||
"name": "capstone-2024-35", | ||
"version": "1.0.0", | ||
"description": "Capstone Project 2024-35", | ||
"main": "index.js", | ||
"scripts": { | ||
"start": "node src/mongodb_store.js" | ||
}, | ||
"dependencies": { | ||
"dotenv": "^16.0.0", | ||
"mongodb": "^4.10.0" | ||
}, | ||
"author": "", | ||
"license": "ISC" | ||
} |
File renamed without changes.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# src/config/db.py | ||
|
||
import os | ||
from dotenv import load_dotenv | ||
from pymongo import MongoClient | ||
|
||
# 환경 변수 로드 | ||
load_dotenv() | ||
mongodb_uri = os.getenv('MONGODB_URI') | ||
|
||
# MongoDB 연결 설정 | ||
def connect_db(): | ||
client = MongoClient(mongodb_uri) | ||
db = client['Authusers'] # 데이터베이스 이름을 여기에서 변경 가능 | ||
return db | ||
def get_collection(collection_name): | ||
db = connect_db() | ||
return db[collection_name] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# employee.py | ||
|
||
from datetime import datetime | ||
from bson import ObjectId | ||
from config.db import get_collection | ||
|
||
class Employee: | ||
def __init__(self, user, name, sex, local, rrn, phonenumber, created_at=None, updated_at=None): | ||
self.user = ObjectId(user) | ||
self.name = name | ||
self.sex = sex | ||
self.local = local | ||
self.rrn = rrn # 주민등록번호 | ||
self.phonenumber = phonenumber | ||
self.created_at = created_at if created_at else datetime.utcnow() | ||
self.updated_at = updated_at if updated_at else datetime.utcnow() | ||
|
||
def to_dict(self): | ||
return { | ||
'user': self.user, | ||
'name': self.name, | ||
'sex': self.sex, | ||
'local': self.local, | ||
'RRN': self.rrn, | ||
'phonenumber': self.phonenumber, | ||
'createdAt': self.created_at, | ||
'updatedAt': self.updated_at | ||
} | ||
|
||
class EmployeeRepository: | ||
""" | ||
Employee 데이터를 관리하는 저장소 클래스. | ||
""" | ||
def __init__(self): | ||
self.collection = get_collection('ExtractedEntities') | ||
|
||
def insert(self, employee: Employee): | ||
""" | ||
새로운 Employee를 삽입합니다. | ||
""" | ||
self.collection.insert_one(employee.to_dict()) | ||
|
||
def find_all(self): | ||
""" | ||
모든 Employee 데이터를 반환합니다. | ||
""" | ||
return list(self.collection.find()) | ||
|
||
def find_by_name(self, name): | ||
""" | ||
이름으로 Employee 데이터를 찾습니다. | ||
""" | ||
return list(self.collection.find({'name': name})) | ||
|
||
def update(self, employee_id, updated_fields): | ||
""" | ||
주어진 Employee ID의 데이터를 업데이트합니다. | ||
""" | ||
self.collection.update_one({'_id': ObjectId(employee_id)}, {'$set': updated_fields}) | ||
|
||
def delete(self, employee_id): | ||
""" | ||
주어진 Employee ID의 데이터를 삭제합니다. | ||
""" | ||
self.collection.delete_one({'_id': ObjectId(employee_id)}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# db_store.py | ||
|
||
from model import load_model_and_tokenizer, predict_entities | ||
from data_processing import find_career_status, find_phone_number, extract_and_combine_entities | ||
from datasets import load_dataset | ||
from config.db import connect_db, get_collection | ||
from employee import Employee, EmployeeRepository | ||
|
||
# MongoDB 데이터베이스 연결 | ||
db = connect_db() | ||
collection = get_collection('ExtractedEntities') # 원하는 컬렉션 이름을 지정 | ||
|
||
# KLUE NER 데이터셋 로드 | ||
dataset = load_dataset("klue", "ner") | ||
tag_list = dataset['train'].features['ner_tags'].feature.names | ||
tag2id = {tag: id for id, tag in enumerate(tag_list)} | ||
id2tag = {id: tag for tag, id in tag2id.items()} | ||
|
||
# 모델 및 토크나이저 로드 | ||
model, tokenizer = load_model_and_tokenizer() | ||
|
||
# 예시 텍스트 | ||
text = "25/ 김준호 /서초구 거주/경력:유/전화번호:010-0000-0000" | ||
|
||
# 엔티티 추출 및 결합 | ||
predicted_entities = predict_entities(text, model, tokenizer, id2tag) | ||
entities_combined = extract_and_combine_entities(predicted_entities) | ||
entities_combined["career"] = find_career_status(text) | ||
entities_combined["phonenumber"] = find_phone_number(text) | ||
entities_combined["sex"] = "남" | ||
entities_combined["RRN"] = "000000-0000000" | ||
|
||
user_id = '609b8b8f8e4f5b88f8e8e8e8' | ||
|
||
new_employee = Employee( | ||
user=user_id, | ||
name=entities_combined["name"], | ||
sex=entities_combined["sex"], | ||
local=entities_combined["local"], | ||
rrn=entities_combined["RRN"], | ||
phonenumber=entities_combined["phonenumber"] | ||
) | ||
# 데이터 MongoDB에 저장 | ||
employee_repo = EmployeeRepository() | ||
employee_repo.insert(new_employee) | ||
|
||
for employee in employee_repo.find_all(): | ||
print(employee) |
File renamed without changes.
6 changes: 3 additions & 3 deletions
6
Data Extract/src/requirements.txt → DataExtract/src/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
torch | ||
transformers | ||
transformers[torch] | ||
datasets | ||
seqeval | ||
kobert-transformers | ||
git+https://github.com/SKTBrain/KoBERT.git@master#egg=kobert | ||
onnxruntime==1.8.0 | ||
git+https://github.com/SKTBrain/KoBERT.git@master | ||
git+https://github.com/SKTBrain/KoBERT.git#egg=kobert_tokenizer&subdirectory=kobert_hf | ||
scikit-learn |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.