diff --git a/alembic_labelu/versions/9d5da133bbe4_replace_key_with_value_in_sample_table.py b/alembic_labelu/versions/9d5da133bbe4_replace_key_with_value_in_sample_table.py index 920340cb..2bce54dc 100644 --- a/alembic_labelu/versions/9d5da133bbe4_replace_key_with_value_in_sample_table.py +++ b/alembic_labelu/versions/9d5da133bbe4_replace_key_with_value_in_sample_table.py @@ -57,41 +57,37 @@ def upgrade() -> None: if label.get("key"): label_dict[label.get("key")] = label.get("value") - # replace key with value in the sample of current task if label_dict is not null - if len(label_dict) > 0: - # get the sample data items of the task id - sample_items = session.execute( - f"SELECT id, data FROM task_sample WHERE task_id={task_id}" - ) - for sample_item in sample_items: - sample_id = sample_item[0] - sample_data_item = json.loads(sample_item[1]) - sample_annotated_result = json.loads(sample_data_item.get("result")) - if sample_annotated_result: - for sample_tool in sample_annotated_result.keys(): - if sample_tool.endswith("Tool"): - for sample_tool_result in sample_annotated_result.get( - sample_tool - ).get("result", []): - tool_label = sample_tool_result.get("attribute", "") - if tool_label in label_dict: - sample_tool_result["attribute"] = label_dict[ - tool_label - ] - sample_data_item["result"] = json.dumps( - sample_annotated_result, ensure_ascii=False - ) - sample_annotated_item_str = json.dumps( - sample_data_item, ensure_ascii=False - ) - op.execute( - update(task_sample) - .where(task_sample.id == sample_id) - .where( - task_sample.task_id == task_id, - ) - .values({task_sample.data: sample_annotated_item_str}) - ) + # replace key with value in the sample of current task + # get the sample data items of the task id + sample_items = session.execute( + f"SELECT id, data FROM task_sample WHERE task_id={task_id}" + ) + for sample_item in sample_items: + sample_id = sample_item[0] + sample_data_item = json.loads(sample_item[1]) + sample_annotated_result = json.loads(sample_data_item.get("result")) + if sample_annotated_result: + for sample_tool in sample_annotated_result.keys(): + if sample_tool.endswith("Tool"): + for sample_tool_result in sample_annotated_result.get( + sample_tool + ).get("result", []): + tool_label = sample_tool_result.get("attribute", "") + if tool_label in label_dict: + sample_tool_result["attribute"] = label_dict[ + tool_label + ] + sample_data_item["result"] = json.dumps( + sample_annotated_result, ensure_ascii=False + ) + sample_annotated_item_str = json.dumps( + sample_data_item, ensure_ascii=False + ) + op.execute( + update(task_sample) + .where(task_sample.id == sample_id) + .values({task_sample.data: sample_annotated_item_str}) + ) def downgrade() -> None: