Further reorganize artifact structure (octoml#156)
This PR reorganizes the artifact structure. There are now two separate
types of directories for storing the libs/weights/...: a single "prebuilt"
directory that holds all prebuilt libs and weights downloaded from the
internet, and per-model directories that are generated by local builds.

The CLI and test scripts are updated accordingly.
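
For orientation, here is a minimal Python sketch of the lookup convention this layout implies, as inferred from the diff below: locally built artifacts live under <artifact_path>/<model>-<quantization>/params/, while downloaded prebuilt artifacts live under <artifact_path>/prebuilt/<model>-<quantization>/. The helper name and the example local id are illustrative, not part of this PR.

import os

def find_mlc_llm_config(artifact_path: str, local_id: str):
    """Resolve mlc-llm-config.json under the reorganized layout (sketch only)."""
    candidates = [
        # Locally built model directory, e.g. dist/<model>-<quantization>/params/
        os.path.join(artifact_path, local_id, "params"),
        # Prebuilt libs/weights downloaded from the internet
        os.path.join(artifact_path, "prebuilt", local_id),
    ]
    for directory in candidates:
        config_path = os.path.join(directory, "mlc-llm-config.json")
        if os.path.isfile(config_path):
            return config_path
    return None

# Example usage (the local id is hypothetical):
# find_mlc_llm_config("dist", "vicuna-v1-7b-q3f16_0")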
MasterJH5574 authored May 16, 2023
1 parent 801f573 commit 60a409b
Showing 6 changed files with 114 additions and 76 deletions.
49 changes: 38 additions & 11 deletions build.py
@@ -16,7 +16,24 @@

def _parse_args():
args = argparse.ArgumentParser()
utils.argparse_add_common(args)
args.add_argument(
"--model-path",
type=str,
default=None,
help="Custom model path that contains params, tokenizer, and config",
)
args.add_argument(
"--hf-path",
type=str,
default=None,
help="Hugging Face path from which to download params, tokenizer, and config from",
)
args.add_argument(
"--quantization",
type=str,
choices=[*utils.quantization_dict.keys()],
default=list(utils.quantization_dict.keys())[0],
)
args.add_argument("--max-seq-len", type=int, default=-1)
args.add_argument("--target", type=str, default="auto")
args.add_argument(
@@ -62,9 +79,12 @@ def _parse_args():

return parsed


def _setup_model_path(args):
if args.model_path and args.hf_path:
assert (args.model_path and not args.hf_path) or (args.hf_path and not args.model_path), "You cannot specify both a model path and a HF path. Please select one to specify."
assert (args.model_path and not args.hf_path) or (
args.hf_path and not args.model_path
), "You cannot specify both a model path and a HF path. Please select one to specify."
if args.model_path:
validate_config(args)
with open(os.path.join(args.model_path, "config.json")) as f:
@@ -78,20 +98,30 @@ def _setup_model_path(args):
else:
os.makedirs(args.model_path, exist_ok=True)
os.system("git lfs install")
os.system(f"git clone https://huggingface.co/{args.hf_path} {args.model_path}")
os.system(
f"git clone https://huggingface.co/{args.hf_path} {args.model_path}"
)
print(f"Downloaded weights to {args.model_path}")
validate_config(args)
else:
raise ValueError(f"Please specify either the model_path or the hf_path.")
print(f"Using model path {args.model_path}")
return args


def validate_config(args):
assert os.path.exists(os.path.join(args.model_path, "config.json")), "Model path must contain valid config file."
assert os.path.exists(
os.path.join(args.model_path, "config.json")
), "Model path must contain valid config file."
with open(os.path.join(args.model_path, "config.json")) as f:
config = json.load(f)
assert ("model_type" in config) and ("_name_or_path" in config), "Invalid config format."
assert config["model_type"] in utils.supported_model_types, f"Model type {config['model_type']} not supported."
assert ("model_type" in config) and (
"_name_or_path" in config
), "Invalid config format."
assert (
config["model_type"] in utils.supported_model_types
), f"Model type {config['model_type']} not supported."


def debug_dump_script(mod, name, args):
"""Debug dump mode"""
@@ -177,7 +207,7 @@ def dump_default_mlc_llm_config(args):
config["stream_interval"] = 2
config["mean_gen_len"] = 128
config["shift_fill_factor"] = 0.3
dump_path = os.path.join(args.artifact_path, "mlc_llm_config.json")
dump_path = os.path.join(args.artifact_path, "params", "mlc-llm-config.json")
with open(dump_path, "w") as outfile:
json.dump(config, outfile, indent=4)
print(f"Finish exporting mlc_llm_config to {dump_path}")
@@ -255,10 +285,7 @@ def dump_split_tir(mod: tvm.IRModule):
mod, params = llama.get_model(ARGS, config)
elif ARGS.model_category == "gpt_neox":
mod, params = gpt_neox.get_model(
ARGS.model,
ARGS.model_path,
ARGS.quantization.model_dtype,
config
ARGS.model, ARGS.model_path, ARGS.quantization.model_dtype, config
)
elif ARGS.model_category == "moss":
mod, params = moss.get_model(ARGS, config)
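
On the build side, the change above moves the chat-time outputs into a params/ subdirectory of the per-model artifact directory: mlc-llm-config.json is dumped there, and (per the mlc_llm/utils.py change further down) the tokenizer files are copied there as well. A hedged sketch of that convention, with a hypothetical helper name:

import json
import os

def dump_chat_config(artifact_path: str, config: dict) -> str:
    """Write mlc-llm-config.json where the reorganized layout expects it (sketch)."""
    params_dir = os.path.join(artifact_path, "params")
    os.makedirs(params_dir, exist_ok=True)
    dump_path = os.path.join(params_dir, "mlc-llm-config.json")
    with open(dump_path, "w") as outfile:
        json.dump(config, outfile, indent=4)
    return dump_path
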
99 changes: 61 additions & 38 deletions cpp/cli_main.cc
@@ -239,71 +239,97 @@ int main(int argc, char* argv[]) {
using namespace tvm::runtime;
argparse::ArgumentParser args("mlc_chat");

args.add_argument("--local-id").default_value("");
args.add_argument("--model").default_value("vicuna-v1-7b");
args.add_argument("--quantization").default_value("auto");
args.add_argument("--device-name").default_value("auto");
args.add_argument("--device_id").default_value(0).scan<'i', int>();
args.add_argument("--artifact-path").default_value("dist");
args.add_argument("--model").default_value("vicuna-v1-7b");
args.add_argument("--quantization").default_value("auto");
args.add_argument("--params").default_value("auto");
args.add_argument("--evaluate").default_value(false).implicit_value(true);

try {
args.parse_args(argc, argv);
} catch (const std::runtime_error& err) {
std::cerr << err.what() << std::endl;
std::cerr << args;
std::cerr << args << std::endl;
return 1;
}

std::string local_id = args.get<std::string>("--local-id");
std::string model = args.get<std::string>("--model");
std::string quantization = args.get<std::string>("--quantization");
std::string device_name = DetectDeviceName(args.get<std::string>("--device-name"));
int device_id = args.get<int>("--device_id");
DLDevice device = GetDevice(device_name, device_id);
std::string artifact_path = args.get<std::string>("--artifact-path");
std::string model = args.get<std::string>("--model");
std::string quantization = args.get<std::string>("--quantization");
std::string params = args.get<std::string>("--params");

std::string arch_suffix = GetArchSuffix();

std::optional<std::filesystem::path> lib_path_opt;
std::vector<std::string> local_id_candidates;
std::optional<std::filesystem::path> config_path_opt;

std::vector<std::string> quantization_candidates;
if (quantization == "auto") {
quantization_candidates = quantization_presets;
// Configure local id candidates.
if (local_id != "") {
local_id_candidates = {local_id};
} else {
quantization_candidates = {quantization};
std::vector<std::string> quantization_candidates;
if (quantization == "auto") {
quantization_candidates = quantization_presets;
} else {
quantization_candidates = {quantization};
}
for (std::string quantization_candidate : quantization_candidates) {
local_id_candidates.push_back(model + "-" + quantization_candidate);
}
}

std::optional<std::filesystem::path> lib_path;
for (auto candidate : quantization_candidates) {
std::string lib_name = model + "-" + candidate + "-" + device_name;
std::vector<std::string> search_paths = {artifact_path + "/" + model + "-" + candidate,
artifact_path + "/" + model, artifact_path + "/lib"};
// search for lib_x86_64 and lib
lib_path_opt = FindFile(search_paths, {lib_name, lib_name + arch_suffix}, GetLibSuffixes());
if (lib_path_opt) {
quantization = candidate;
// Search for mlc-llm-config.json.
for (auto local_id_candidate : local_id_candidates) {
std::vector<std::string> config_search_paths = {
artifact_path + "/" + local_id_candidate + "/params", //
artifact_path + "/prebuilt/" + local_id_candidate};
config_path_opt = FindFile(config_search_paths, {"mlc-llm-config"}, {".json"});
if (config_path_opt) {
local_id = local_id_candidate;
break;
}
}
if (!config_path_opt) {
std::cerr << "Cannot find \"mlc-llm-config.json\" in path \"" << artifact_path << "/"
<< local_id_candidates[0] << "/params/\", \"" << artifact_path
<< "/prebuilt/" + local_id_candidates[0] << "\" or other candidate paths.";
return 1;
}
std::cout << "Use config " << config_path_opt.value().string() << std::endl;
std::filesystem::path model_path = config_path_opt.value().parent_path();

// Locate the library.
std::string lib_name = local_id + "-" + device_name;
std::string lib_dir_path;
if (model_path.string().compare(model_path.string().length() - 7, 7, "/params") == 0) {
lib_dir_path = model_path.parent_path().string();
} else {
lib_dir_path = model_path.parent_path().string() + "/lib";
}
std::optional<std::filesystem::path> lib_path_opt =
FindFile({lib_dir_path}, {lib_name, lib_name + arch_suffix}, GetLibSuffixes());
if (!lib_path_opt) {
std::cerr << "Cannot find " << model << " lib in preferred path \"" << artifact_path << "/"
<< model << "-" << quantization_candidates[0] << "/" << model << "-"
<< quantization_candidates[0] << "-" << device_name << GetLibSuffixes()[0]
<< "\" or other candidate paths";
std::cerr << "Cannot find library \"" << lib_name << GetLibSuffixes().back()
<< "\" and other library candidate in " << lib_dir_path << std::endl;
return 1;
}
std::cout << "Use lib " << lib_path_opt.value().string() << std::endl;
std::string model_path = lib_path_opt.value().parent_path().string();
LOG(INFO) << "model_path = " << model_path;
// get artifact path lib name

// Locate the tokenizer files.
std::optional<std::filesystem::path> tokenizer_path_opt =
FindFile({model_path, artifact_path + "/" + model}, {"tokenizer"}, {".model", ".json"});
FindFile({model_path.string()}, {"tokenizer"}, {".model", ".json"});
if (!tokenizer_path_opt) {
// Try ByteLevelBPETokenizer
tokenizer_path_opt = FindFile({model_path, artifact_path + "/" + model}, {"vocab"}, {".json"});
tokenizer_path_opt = FindFile({model_path.string()}, {"vocab"}, {".json"});
if (!tokenizer_path_opt) {
std::cerr << "Cannot find tokenizer file in " << model_path;
std::cerr << "Cannot find tokenizer file in " << model_path.string() << std::endl;
return 1;
} else {
// GPT2 styles tokenizer needs multiple files, we need to
@@ -312,19 +338,16 @@ int main(int argc, char* argv[]) {
}
}

// Locate the params.
if (params == "auto") {
auto params_json_opt =
FindFile({model_path + "/params", artifact_path + "/" + model + "/params"},
{"ndarray-cache"}, {".json"});
auto params_json_opt = FindFile({model_path}, {"ndarray-cache"}, {".json"});
if (!params_json_opt) {
std::cerr << "Cannot find ndarray-cache.json for params in preferred path \"" << model_path
<< "/params\" and \"" << artifact_path << "/" + model << "/params.";
std::cerr << "Cannot find ndarray-cache.json for params in " << model_path << std::endl;
return 1;
}
std::string params_json = params_json_opt.value().string();
params = params_json.substr(0, params_json.length() - 18);
params = params_json_opt.value().parent_path().string();
} else if (!FindFile({params}, {"ndarray-cache"}, {".json"})) {
std::cerr << "Cannot find params/ndarray-cache.json in " << model_path;
std::cerr << "Cannot find ndarray-cache.json for params in " << params << std::endl;
return 1;
}

@@ -345,7 +368,7 @@
} catch (const std::runtime_error& err) {
// catch exception so error message
// get reported here without silently quit.
std::cerr << err.what();
std::cerr << err.what() << std::endl;
return 1;
}
return 0;
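
The library search added to cli_main.cc above can be paraphrased in Python as follows. This is a simplified sketch: the library suffixes are assumed to be platform-typical, and the per-architecture variant (lib_name + arch_suffix) tried by the real CLI is omitted.

import os

def find_model_lib(config_path: str, local_id: str, device_name: str,
                   lib_suffixes=(".so", ".dylib", ".dll")):
    """Locate the compiled model library relative to mlc-llm-config.json (sketch)."""
    model_dir = os.path.dirname(config_path)
    if os.path.basename(model_dir) == "params":
        # Locally built layout: dist/<local-id>/params/ -> library sits in dist/<local-id>/
        lib_dir = os.path.dirname(model_dir)
    else:
        # Prebuilt layout: dist/prebuilt/<local-id>/ -> library sits in dist/prebuilt/lib/
        lib_dir = os.path.join(os.path.dirname(model_dir), "lib")
    lib_basename = f"{local_id}-{device_name}"
    for suffix in lib_suffixes:
        candidate = os.path.join(lib_dir, lib_basename + suffix)
        if os.path.isfile(candidate):
            return candidate
    return None
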
24 changes: 4 additions & 20 deletions mlc_llm/utils.py
@@ -46,25 +46,6 @@ class Quantization:

supported_model_types = set(["llama", "gpt_neox", "moss"])

def argparse_add_common(args: argparse.ArgumentParser) -> None:
args.add_argument(
"--quantization",
type=str,
choices=[*quantization_dict.keys()],
default=list(quantization_dict.keys())[0],
)
args.add_argument(
"--model-path",
type=str,
default=None,
help="Custom model path that contains params, tokenizer, and config"
)
args.add_argument(
"--hf-path",
type=str,
default=None,
help="Hugging Face path from which to download params, tokenizer, and config from"
)

def argparse_postproc_common(args: argparse.Namespace) -> None:
if hasattr(args, "device_name"):
@@ -208,7 +189,10 @@ def _is_static_shape_func(func: tvm.tir.PrimFunc):
def copy_tokenizer(args: argparse.Namespace) -> None:
for filename in os.listdir(args.model_path):
if filename.startswith("tokenizer") or filename == "vocab.json":
shutil.copy(os.path.join(args.model_path, filename), args.artifact_path)
shutil.copy(
os.path.join(args.model_path, filename),
os.path.join(args.artifact_path, "params"),
)


def parse_target(args: argparse.Namespace) -> None:
5 changes: 3 additions & 2 deletions tests/chat.py
@@ -28,12 +28,13 @@ class Colors:

def _parse_args():
args = argparse.ArgumentParser()
utils.argparse_add_common(args)
args.add_argument("--local-id", type=str, required=True)
args.add_argument("--device-name", type=str, default="auto")
args.add_argument("--debug-dump", action="store_true", default=False)
args.add_argument("--artifact-path", type=str, default="dist")
args.add_argument("--max-gen-len", type=int, default=2048)
parsed = args.parse_args()
parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1)
utils.argparse_postproc_common(parsed)
parsed.artifact_path = os.path.join(
parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}"
@@ -223,7 +224,7 @@ def main():
if ARGS.debug_dump:
torch.manual_seed(12)
tokenizer = AutoTokenizer.from_pretrained(
ARGS.artifact_path, trust_remote_code=True
os.path.join(ARGS.artifact_path, "params"), trust_remote_code=True
)
tokenizer.pad_token_id = tokenizer.eos_token_id
if ARGS.model.startswith("dolly-"):
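
tests/chat.py above (and the other test scripts below) now take a single --local-id of the form <model>-<quantization> and split it on the last dash. A tiny illustration; the concrete id and quantization name are examples, not values mandated by this PR:

import os

local_id = "vicuna-v1-7b-q3f16_0"              # hypothetical example value
model, quantization = local_id.rsplit("-", 1)  # -> ("vicuna-v1-7b", "q3f16_0")

# The scripts then point the artifact path at the per-model directory,
# and load the tokenizer from its params/ subdirectory.
artifact_path = os.path.join("dist", f"{model}-{quantization}")
params_dir = os.path.join(artifact_path, "params")
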
8 changes: 5 additions & 3 deletions tests/debug/compare_lib.py
@@ -135,7 +135,9 @@ def deploy_to_pipeline(args) -> None:
primary_device = tvm.device(args.primary_device)
const_params = utils.load_params(args.artifact_path, primary_device)
state = TestState(args)
tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
os.path.join(args.artifact_path, "params"), trust_remote_code=True
)

print("Tokenizing...")
inputs = tvm.nd.array(
@@ -177,17 +179,17 @@ def deploy_to_pipeline(args) -> None:

def _parse_args():
args = argparse.ArgumentParser()
utils.argparse_add_common(args)
args.add_argument("--local-id", type=str, required=True)
args.add_argument("--artifact-path", type=str, default="dist")
args.add_argument("--primary-device", type=str, default="auto")
args.add_argument("--cmp-device", type=str, required=True)
args.add_argument("--prompt", type=str, default="The capital of Canada is")
args.add_argument("--time-eval", default=False, action="store_true")
args.add_argument("--skip-rounds", type=int, default=0)
parsed = args.parse_args()
parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1)
utils.argparse_postproc_common(parsed)

parsed.model_path = os.path.join(parsed.artifact_path, "models", parsed.model)
parsed.artifact_path = os.path.join(
parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}"
)
5 changes: 3 additions & 2 deletions tests/evaluate.py
@@ -18,13 +18,14 @@

def _parse_args():
args = argparse.ArgumentParser()
utils.argparse_add_common(args)
args.add_argument("--local-id", type=str, required=True)
args.add_argument("--device-name", type=str, default="auto")
args.add_argument("--debug-dump", action="store_true", default=False)
args.add_argument("--artifact-path", type=str, default="dist")
args.add_argument("--prompt", type=str, default="The capital of Canada is")
args.add_argument("--profile", action="store_true", default=False)
parsed = args.parse_args()
parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1)
utils.argparse_postproc_common(parsed)
parsed.artifact_path = os.path.join(
parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}"
@@ -91,7 +92,7 @@ def deploy_to_pipeline(args) -> None:
vm = relax.VirtualMachine(ex, device)

tokenizer = AutoTokenizer.from_pretrained(
args.artifact_path, trust_remote_code=True
os.path.join(args.artifact_path, "params"), trust_remote_code=True
)

print("Tokenizing...")
