Skip to content

Commit

Permalink
Merge pull request #917 from chandralegend/ph/persitant-inference-list
Browse files Browse the repository at this point in the history
Persistance for PH
  • Loading branch information
ypkang authored Feb 14, 2023
2 parents d48faac + 36b68a7 commit 8adafbf
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
40 changes: 38 additions & 2 deletions jaseci_ai_kit/jac_misc/jac_misc/ph/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
import os
import shutil
from collections import OrderedDict
import json

from .utils import model as model_module
from .utils import process as process_module
from .utils.base import BaseInference
from .utils.logger import get_logger
import logging
from .utils.util import deep_update, write_yaml
from .utils.util import deep_update, write_yaml, read_yaml

from .train import train

Expand Down Expand Up @@ -91,8 +92,14 @@ class InferenceList:

def __init__(self, config: Dict = None) -> None: # type: ignore
self.config = config
with open("heads/config.yaml", "w") as f:
json.dump(config, f)
os.makedirs("heads", exist_ok=True)
self.ie_list = {}
self.ph_list = []
if os.path.exists("heads/ph_list.json"):
with open("heads/ph_list.json", "r") as f:
self.ph_list = json.load(f)

def add(self, config: Dict = None, uuid: str = None) -> str: # type: ignore
if self.check(uuid):
Expand All @@ -104,6 +111,14 @@ def add(self, config: Dict = None, uuid: str = None) -> str: # type: ignore
else:
ie = InferenceEngine(self.config, uuid)
self.ie_list[ie.id] = ie
if not os.path.exists("heads/ph_list.json"):
with open("heads/ph_list.json", "w") as f:
json.dump([], f)
with open("heads/ph_list.json", "r") as f:
self.ph_list = json.load(f)
self.ph_list.append(ie.id)
with open("heads/ph_list.json", "w") as f:
json.dump(self.ph_list, f)
return ie.id

def predict(self, uuid: str, data: Any) -> Any:
Expand Down Expand Up @@ -153,5 +168,26 @@ def get_config(self, uuid: str) -> Dict:
else:
raise ImproperConnectionState("Inference Engine not found.")

def delete_head(self, uuid: str) -> None:
if self.check(uuid):
del self.ie_list[uuid]
self.ph_list.remove(uuid)
with open("heads/ph_list.json", "w") as f:
json.dump(self.ph_list, f)
shutil.rmtree(f"heads/{uuid}")
else:
raise ImproperConnectionState("Inference Engine not found.")

def load_head(self, uuid: str) -> None:
config = read_yaml(f"heads/{uuid}/config.yaml")
self.ie_list[uuid] = InferenceEngine(config, uuid)
if os.path.exists(f"heads/{uuid}/current.pth"):
self.ie_list[uuid].load_weights(f"heads/{uuid}/current.pth")

def check(self, uuid: str) -> bool:
return uuid in self.ie_list
if uuid in self.ph_list and uuid in self.ie_list:
return True
if uuid in self.ph_list:
self.load_head(uuid)
return True
return False
8 changes: 7 additions & 1 deletion jaseci_ai_kit/jac_misc/jac_misc/ph/ph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import traceback
from fastapi import HTTPException
import logging

from jaseci.actions.live_actions import jaseci_action

Expand All @@ -21,7 +22,12 @@ def setup():
global il, list_config
dirname = os.path.dirname(__file__)
list_config = read_yaml(os.path.join(dirname, "config.yaml"))
il = None
if os.path.exists("heads/config.yaml") and os.path.exists("heads/custom.py"):
logging.warning("Found a heads list in the current directory. Loading it ...")
il = InferenceList(config=read_yaml("heads/config.yaml"))
else:
logging.info("No heads list found. Run create_head_list to create one.")
il = None


setup()
Expand Down

0 comments on commit 8adafbf

Please sign in to comment.