Skip to content

Commit

Permalink
Merge pull request #30 from eubinecto/issue_28
Browse files Browse the repository at this point in the history
[#27] run_deploy.py done. with some formatting
  • Loading branch information
eubinecto authored Apr 7, 2022
2 parents b1c796e + 74815d4 commit e718b0a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 10 deletions.
1 change: 1 addition & 0 deletions cleanrnns/fetchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def fetch_model_for_ner():

def fetch_pipeline_for_classification(entity: str, name: str, run: Run = None) -> PipelineForClassification:
model = fetch_model_for_classification(entity, name, run)
model.eval()
tokenizer = fetch_tokenizer(entity, run)
config = fetch_config()[name]
pipeline = PipelineForClassification(model, tokenizer, config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ def main():
pipeline = fetch_pipeline_for_classification("eubinecto", "lstm_for_classification")
print(pipeline("너무 좋다")) # just a sanity check
print(pipeline("재미없다")) # just a sanity check
# 향후 테스트를 진행할 계획
pipeline = fetch_pipeline_for_classification("eubinecto", "bilstm_for_classification")
print(pipeline("너무 좋다")) # just a sanity check
print(pipeline("재미없다")) # just a sanity check
shutil.rmtree("artifacts") # clear the cache after testing


Expand Down
28 changes: 18 additions & 10 deletions run_deploy.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,46 @@
"""
use streamlit to deploy them
"""
from typing import Tuple

import pandas as pd
import streamlit as st
from typing import Tuple
from cleanrnns.fetchers import fetch_pipeline_for_classification
from cleanrnns.pipelines import PipelineForClassification


@st.cache(allow_output_mutation=True)
def cache_pipeline() -> Tuple[PipelineForClassification,
                              PipelineForClassification,
                              PipelineForClassification]:
    """Fetch the three sentiment-classification pipelines (RNN, LSTM, BiLSTM).

    Decorated with ``st.cache`` so the (presumably expensive) fetch runs only
    once per Streamlit session. ``allow_output_mutation=True`` tells Streamlit
    not to hash the returned objects on every rerun — the pipelines hold
    mutable model state.  NOTE(review): the source span contained diff residue
    (both the old 2-tuple and new 3-tuple versions); this is the reconstructed
    post-commit version.

    :returns: a ``(rnn, lstm, bilstm)`` tuple of pipelines.
    """
    rnn = fetch_pipeline_for_classification("eubinecto", "rnn_for_classification")
    lstm = fetch_pipeline_for_classification("eubinecto", "lstm_for_classification")
    bilstm = fetch_pipeline_for_classification("eubinecto", "bilstm_for_classification")
    return rnn, lstm, bilstm


def main():
# fetch a pre-trained model
rnn, lstm = cache_pipeline()
rnn, lstm, bilstm = cache_pipeline()
st.title("The Clean Rnns - 긍 / 부정 감성분석")
text = st.text_input("문장을 입력하세요", value="난 너가 정말 좋으면서도 싫다")
text = st.text_input("문장을 입력하세요", value="제목은 시선을 끌지만 줄거리가 애매모호하다")
if st.button(label="분석하기"):
with st.spinner("Please wait..."):
with st.spinner("로딩중..."):
# prediction with RNN
table = list()
pred, probs = rnn(text)
sentiment = "`긍정`" if pred else "`부정`"
sentiment = "🟢(긍정)" if pred else "🔴(부정)"
probs = ["{:.4f}".format(prob) for prob in probs]
table.append(["RNN", sentiment, str(probs)])
# prediction with LSTM
pred, probs = lstm(text)
sentiment = "`긍정`" if pred else "`부정`"
probs = ["{:.4f}".format(prob) for prob in probs]
sentiment = "🟢(긍정)" if pred else "🔴(부정)"
table.append(["LSTM", sentiment, str(probs)])
df = pd.DataFrame(table, columns=["모델", "예측", "확률분포"])
pred, probs = bilstm(text)
sentiment = "🟢(긍정)" if pred else "🔴(부정)"
probs = ["{:.4f}".format(prob) for prob in probs]
table.append(["RNN", sentiment, str(probs)])
df = pd.DataFrame(table, columns=["모델", "예측", "확률분포 [부정, 긍정]"])
st.markdown(df.to_markdown(index=False))


Expand Down

0 comments on commit e718b0a

Please sign in to comment.