Skip to content

Commit

Permalink
changed database
Browse files Browse the repository at this point in the history
  • Loading branch information
Kathryn-cat committed Jun 10, 2022
1 parent 7db45ee commit 69fe933
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 123 deletions.
104 changes: 0 additions & 104 deletions ;w

This file was deleted.

31 changes: 12 additions & 19 deletions python/tvm/meta_schedule/testing/dataset_sample_candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,21 +109,6 @@ def sample_candidates(task, task_name, model_name):
evolve_with_cost_model = tvm.get_global_func(
"meta_schedule.SearchStrategyEvolutionarySearchEvolveWithCostModel"
)
tuning_record_path = os.path.join(
args.candidate_cache_dir,
model_name,
task_name + "_tuning_record.json",
)
workload_path = os.path.join(
args.candidate_cache_dir,
model_name,
task_name + "_workload.json",
)
database = ms.database.JSONDatabase(
path_workload=workload_path,
path_tuning_record=tuning_record_path,
)

strategy = ms.search_strategy.EvolutionarySearch(
num_trials_per_iter=args.num_trials_per_iter,
max_trials_per_task=args.max_trials_per_task,
Expand All @@ -143,7 +128,7 @@ def sample_candidates(task, task_name, model_name):
context.initialize()
context.pre_tuning(
context.generate_design_space(),
database=database,
database=ms.database.MemoryDatabase(),
cost_model=ms.cost_model.RandomModel(), # type: ignore
)

Expand All @@ -162,9 +147,17 @@ def sample_candidates(task, task_name, model_name):
itr += 1
all_states = all_states[: args.num_samples_per_task]

workload = database.commit_workload(mod=task)
for state in all_states:
database.commit_tuning_record(ms.database.TuningRecord(state.trace, workload))
workload = ms.database.Workload(context.mod)
file_path = os.path.join(args.candidate_cache_dir, model_name, task_name + ".json")
with open(file_path, "w", encoding="utf8") as file:
for i, state in enumerate(all_states):
tuning_record = ms.database.TuningRecord(state.trace, workload)
json_str = json.dumps(tuning_record.as_json())
assert "\n" not in json_str, "Failed to generate single line string."
if i == len(all_states) - 1:
file.write(json_str)
else:
file.write(json_str + "\n")


args = _parse_args() # pylint: disable=invalid-name
Expand Down

0 comments on commit 69fe933

Please sign in to comment.