diff --git a/backend/composer/services/export_services.py b/backend/composer/services/export_services.py
index ffc7a96c..d90c6f9d 100644
--- a/backend/composer/services/export_services.py
+++ b/backend/composer/services/export_services.py
@@ -7,13 +7,21 @@
 from django.contrib.auth.models import User
 from django.db import transaction
-from django.db.models import Count, QuerySet
+from django.db.models import Count, QuerySet, Prefetch
 from django.utils import timezone
 
-from composer.enums import CSState
-from composer.enums import NoteType, ExportRelationships, CircuitType, Laterality, MetricEntity, DestinationType, \
-    ViaType, SentenceState, \
-    Projection
+from composer.enums import (
+    CSState,
+    NoteType,
+    ExportRelationships,
+    CircuitType,
+    Laterality,
+    MetricEntity,
+    DestinationType,
+    ViaType,
+    SentenceState,
+    Projection,
+)
 from composer.exceptions import UnexportableConnectivityStatement
 from composer.models import (
     Tag,
@@ -22,10 +30,15 @@
     ExportMetrics,
     Sentence,
     Specie,
-    Via, AnatomicalEntity, Destination,
+    Via,
+    AnatomicalEntity,
+    Destination,
+    Note,
+)
+from composer.services.connections_service import (
+    get_complete_from_entities_for_destination,
+    get_complete_from_entities_for_via,
 )
-from composer.services.connections_service import get_complete_from_entities_for_destination, \
-    get_complete_from_entities_for_via
 
 from composer.services.filesystem_service import create_dir_if_not_exists
 from composer.services.state_services import ConnectivityStatementStateService
@@ -61,16 +74,16 @@ class Row:
     def __init__(
-            self,
-            structure: str,
-            identifier: str,
-            relationship: str,
-            predicate: str,
-            curation_notes: str = "",
-            review_notes: str = "",
-            layer: str = "",
-            connected_from_names: str = "",
-            connected_from_uris: str = ""
+        self,
+        structure: str,
+        identifier: str,
+        relationship: str,
+        predicate: str,
+        curation_notes: str = "",
+        review_notes: str = "",
+        layer: str = "",
+        connected_from_names: str = "",
+        connected_from_uris: str = "",
     ):
         self.structure = structure
         self.identifier = identifier
@@ -96,14 +109,13 @@ def get_nlp_id(cs: ConnectivityStatement, row: Row):
 
 
 def get_neuron_population_label(cs: ConnectivityStatement, row: Row):
-    return ' '.join(cs.get_journey())
+    return " ".join(cs.get_journey())
 
 
 def get_type(cs: ConnectivityStatement, row: Row):
     return cs.phenotype.name if cs.phenotype else ""
 
 
-
 def get_structure(cs: ConnectivityStatement, row: Row):
     return row.structure
@@ -133,7 +145,7 @@ def get_predicate(cs: ConnectivityStatement, row: Row):
 def get_observed_in_species(cs: ConnectivityStatement, row: Row):
-    return ", ".join([specie.name for specie in cs.species.all()])
+    return ", ".join(specie.name for specie in cs.species.all())
 
 
 def escape_newlines(value):
@@ -141,21 +153,22 @@
 def get_different_from_existing(cs: ConnectivityStatement, row: Row):
-    return escape_newlines(
-        "\n".join([note.note for note in cs.notes.filter(type=NoteType.DIFFERENT)])
-    )
+    different_notes = [
+        note.note for note in cs.prefetched_notes if note.type == NoteType.DIFFERENT
+    ]
+    return escape_newlines("\n".join(different_notes))
 
 
 def get_curation_notes(cs: ConnectivityStatement, row: Row):
-    return escape_newlines(row.curation_notes.replace("\\", "\\\\"))
+    return escape_newlines(row.curation_notes)
 
 
 def get_review_notes(cs: ConnectivityStatement, row: Row):
-    return escape_newlines(row.review_notes.replace("\\", "\\\\"))
+    return escape_newlines(row.review_notes)
 
 
 def get_reference(cs: ConnectivityStatement, row: Row):
-    return ", ".join([procenance.uri for procenance in cs.provenance_set.all()])
+    return ", ".join(provenance.uri for provenance in cs.provenance_set.all())
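Reviewer note on the hunks above: `get_different_from_existing` (and the tag helpers further down) stop querying the database per statement and instead read `cs.prefetched_notes`, a plain list attached with `Prefetch(..., to_attr=...)` in `dump_export_batch` at the bottom of this file. A minimal sketch of that pattern, using only names that appear in this diff:

```python
from django.db.models import Prefetch

from composer.enums import NoteType
from composer.models import ConnectivityStatement, Note

# One query loads the relevant notes for the whole batch; to_attr stores the
# result as a plain Python list on each statement instance.
statements = ConnectivityStatement.objects.prefetch_related(
    Prefetch(
        "notes",
        queryset=Note.objects.filter(type__in=[NoteType.PLAIN, NoteType.DIFFERENT]),
        to_attr="prefetched_notes",
    )
)

for cs in statements:
    # In-memory filtering; no per-statement query like cs.notes.filter(...).
    different = [n.note for n in cs.prefetched_notes if n.type == NoteType.DIFFERENT]
```

One thing to flag in review: the getters now assume the attribute exists, so a `ConnectivityStatement` loaded without this prefetch raises `AttributeError` instead of falling back to a query.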
return ", ".join([procenance.uri for procenance in cs.provenance_set.all()]) + return ", ".join(procenance.uri for procenance in cs.provenance_set.all()) def is_approved_by_sawg(cs: ConnectivityStatement, row: Row): @@ -171,12 +184,12 @@ def get_added_to_sckan_timestamp(cs: ConnectivityStatement, row: Row): def has_nerve_branches(cs: ConnectivityStatement, row: Row) -> bool: - return cs.tags.filter(tag=HAS_NERVE_BRANCHES_TAG).exists() + return any(tag.tag == HAS_NERVE_BRANCHES_TAG for tag in cs.prefetched_tags) def get_tag_filter(tag_name): def tag_filter(cs, row): - return cs.tags.filter(tag=tag_name).exists() + return any(tag.tag == tag_name for tag in cs.prefetched_tags) return tag_filter @@ -203,7 +216,7 @@ def generate_csv_attributes_mapping() -> Dict[str, Callable]: "Review notes": get_review_notes, "Proposed action": get_proposed_action, "Added to SCKAN (time stamp)": get_added_to_sckan_timestamp, - 'URI': get_statement_uri, + "URI": get_statement_uri, } exportable_tags = Tag.objects.filter(exportable=True) for tag in exportable_tags: @@ -220,13 +233,14 @@ def get_origin_row(origin: AnatomicalEntity, review_notes: str, curation_notes: ExportRelationships.hasSomaLocatedIn.value, curation_notes, review_notes, - layer='1' + layer="1", ) def get_destination_row(destination: Destination, total_vias: int): - if destination.from_entities.exists(): - connected_from_entities = destination.from_entities.all() + from_entities = list(destination.from_entities.all()) + if from_entities: + connected_from_entities = from_entities else: connected_from_entities = get_complete_from_entities_for_destination(destination) @@ -243,15 +257,16 @@ def get_destination_row(destination: Destination, total_vias: int): "", layer=layer_value, connected_from_names=connected_from_names, - connected_from_uris=connected_from_uris + connected_from_uris=connected_from_uris, ) for ae in destination.anatomical_entities.all() ] def get_via_row(via: Via): - if via.from_entities.exists(): - connected_from_entities = via.from_entities.all() + from_entities = list(via.from_entities.all()) + if from_entities: + connected_from_entities = from_entities else: connected_from_entities = get_complete_from_entities_for_via(via) @@ -268,7 +283,7 @@ def get_via_row(via: Via): "", layer=layer_value, connected_from_names=connected_from_names, - connected_from_uris=connected_from_uris + connected_from_uris=connected_from_uris, ) for ae in via.anatomical_entities.all() ] @@ -276,8 +291,8 @@ def get_via_row(via: Via): def _get_connected_from_info(entities): connected_from_info = [(entity.name, entity.ontology_uri) for entity in entities] if entities else [] - connected_from_names = '; '.join(name for name, _ in connected_from_info) - connected_from_uris = '; '.join(uri for _, uri in connected_from_info) + connected_from_names = "; ".join(name for name, _ in connected_from_info) + connected_from_uris = "; ".join(uri for _, uri in connected_from_info) return connected_from_names, connected_from_uris @@ -314,7 +329,7 @@ def get_circuit_role_row(cs: ConnectivityStatement): ) -def get_laterality_row(cs: ConnectivityStatement): +def get_projection_row(cs: ConnectivityStatement): return Row( cs.get_projection_display(), TEMP_PROJECTION_MAP.get(cs.projection, ""), @@ -350,7 +365,7 @@ def get_phenotype_row(cs: ConnectivityStatement): def get_projection_phenotype_row(cs: ConnectivityStatement): - projection_phenotype = cs.projection_phenotype if cs.projection_phenotype else "" + projection_phenotype = cs.projection_phenotype.name if 
@@ -350,7 +365,7 @@ def get_phenotype_row(cs: ConnectivityStatement):
 
 
 def get_projection_phenotype_row(cs: ConnectivityStatement):
-    projection_phenotype = cs.projection_phenotype if cs.projection_phenotype else ""
+    projection_phenotype = cs.projection_phenotype.name if cs.projection_phenotype else ""
     projection_phenotype_ontology_uri = cs.projection_phenotype.ontology_uri if cs.projection_phenotype else ""
     return Row(
@@ -365,7 +380,7 @@
 def get_functional_circuit_row(cs: ConnectivityStatement):
     return Row(
-        cs.functional_circuit_role,
+        cs.functional_circuit_role.name,
         cs.functional_circuit_role.ontology_uri,
         ExportRelationships.hasFunctionalCircuitRolePhenotype.label,
         ExportRelationships.hasFunctionalCircuitRolePhenotype.value,
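The two `.name` fixes above change what lands in the CSV cell. Previously the model instance itself was passed into `Row`, so the cell received whatever `str()` yields for that object (unless the model happens to define a suitable `__str__`). A runnable illustration with a hypothetical stand-in class:

```python
# Hypothetical stand-in for the composer model; only the two fields matter here.
class FunctionalCircuitRole:
    def __init__(self, name: str, ontology_uri: str):
        self.name = name
        self.ontology_uri = ontology_uri

role = FunctionalCircuitRole("intrinsic", "http://example.org/FCR_1")

print(str(role))  # before: '<__main__.FunctionalCircuitRole object at 0x...>'
print(role.name)  # after: 'intrinsic', the value the CSV should carry
```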
@@ -381,93 +396,80 @@ def get_forward_connection_row(forward_conn: ConnectivityStatement):
         ExportRelationships.hasForwardConnection.label,
         ExportRelationships.hasForwardConnection.value,
         "",
-        ""
+        "",
     )
 
 
-def get_rows(cs: ConnectivityStatement) -> List:
+def get_rows(cs: ConnectivityStatement) -> List[Row]:
     rows = []
-    review_notes = "\n".join(
-        [note.note for note in cs.notes.filter(type=NoteType.PLAIN)]
+    # Use prefetched notes
+    plain_notes = [
+        note.note for note in cs.prefetched_notes if note.type == NoteType.PLAIN
+    ]
+    review_notes = "\n".join(plain_notes)
+    curation_notes = "\n".join(
+        note.note for note in cs.sentence.prefetched_sentence_notes
     )
-    curation_notes = "\n".join([note.note for note in cs.sentence.notes.all()])
 
-    for origin in cs.origins.all():
-        try:
-            origin_row = get_origin_row(origin, review_notes, curation_notes)
-            rows.append(origin_row)
-        except Exception:
-            raise UnexportableConnectivityStatement("Error getting origin row")
-
-    for via in cs.via_set.all().order_by("order"):
-        try:
-            via_rows = get_via_row(via)
-            rows.extend(via_rows)
-        except Exception:
-            raise UnexportableConnectivityStatement("Error getting via row")
-
-    total_vias = cs.via_set.count()
-    for destination in cs.destinations.all():
-        try:
-            destination_rows = get_destination_row(destination, total_vias)
-            rows.extend(destination_rows)
-        except Exception:
-            raise UnexportableConnectivityStatement("Error getting destination row")
+    # Origins
+    origins = cs.origins.all()
+    for origin in origins:
+        origin_row = get_origin_row(origin, review_notes, curation_notes)
+        rows.append(origin_row)
+
+    # Vias (ordered by 'order' attribute)
+    vias = cs.via_set.all().order_by("order")
+    total_vias = vias.count()
+    for via in vias:
+        via_rows = get_via_row(via)
+        rows.extend(via_rows)
+
+    # Destinations
+    destinations = cs.destinations.all()
+    for destination in destinations:
+        destination_rows = get_destination_row(destination, total_vias)
+        rows.extend(destination_rows)
 
+    # Species
     for specie in cs.species.all():
-        try:
-            rows.append(get_specie_row(specie))
-        except Exception:
-            raise UnexportableConnectivityStatement("Error getting specie row")
+        rows.append(get_specie_row(specie))
 
+    # Sex
     if cs.sex is not None:
-        try:
-            rows.append(get_sex_row(cs))
-        except Exception:
-            raise UnexportableConnectivityStatement("Error getting sex row")
+        rows.append(get_sex_row(cs))
 
-    try:
+    # Circuit Role
+    if cs.circuit_type is not None:
         rows.append(get_circuit_role_row(cs))
-    except Exception:
-        raise UnexportableConnectivityStatement("Error getting circuit type row")
 
-    try:
-        rows.append(get_laterality_row(cs))
-    except Exception:
-        raise UnexportableConnectivityStatement("Error getting laterality row")
+    # Projection
+    if cs.projection is not None:
+        rows.append(get_projection_row(cs))
 
-    try:
+    # Soma Phenotype
+    if cs.laterality is not None:
         rows.append(get_soma_phenotype_row(cs))
-    except Exception:
-        raise UnexportableConnectivityStatement("Error getting soma phenotype row")
 
-    try:
+    # Phenotype
+    if cs.phenotype is not None:
         rows.append(get_phenotype_row(cs))
-    except Exception:
-        raise UnexportableConnectivityStatement("Error getting phenotype row")
 
+    # Projection Phenotype
     if cs.projection_phenotype:
-        try:
-            rows.append(get_projection_phenotype_row(cs))
-        except Exception:
-            raise UnexportableConnectivityStatement("Error getting projection phenotype row")
+        rows.append(get_projection_phenotype_row(cs))
 
+    # Functional Circuit Role
     if cs.functional_circuit_role:
-        try:
-            rows.append(get_functional_circuit_row(cs))
-        except Exception:
-            raise UnexportableConnectivityStatement("Error getting functinal circuit role row")
+        rows.append(get_functional_circuit_row(cs))
 
+    # Forward Connections
     for forward_conn in cs.forward_connection.all():
-        try:
-            rows.append(get_forward_connection_row(forward_conn))
-        except Exception:
-            raise UnexportableConnectivityStatement("Error getting forward connection row")
+        rows.append(get_forward_connection_row(forward_conn))
 
     return rows
 
 
 def create_export_batch(qs: QuerySet, user: User) -> ExportBatch:
-    # do transition to EXPORTED state
     export_batch = ExportBatch.objects.create(user=user)
     export_batch.connectivity_statements.set(qs)
     export_batch.save()
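The `get_rows` rewrite above deserves a close look in review. The old per-section `try/except Exception` blocks converted any bug in a row builder, including an `AttributeError` or `KeyError`, into `UnexportableConnectivityStatement`, which `dump_export_batch` then swallows with a warning. The new version guards on the nullable fields instead, so genuine errors propagate:

```python
# Old shape: any exception becomes "unexportable" and the statement is
# silently skipped by the caller.
try:
    rows.append(get_phenotype_row(cs))
except Exception:
    raise UnexportableConnectivityStatement("Error getting phenotype row")

# New shape: only the legitimate "field not set" case is skipped; real bugs
# surface as ordinary exceptions.
if cs.phenotype is not None:
    rows.append(get_phenotype_row(cs))
```

Two details worth double-checking: `vias.count()` still issues one COUNT query per statement, since `order_by()` on the related manager bypasses the prefetch cache, and the guard for `get_soma_phenotype_row` tests `cs.laterality`, which reads as intentional only if the soma phenotype row is derived from laterality.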
@@ -475,14 +477,15 @@
 def compute_metrics(export_batch: ExportBatch):
-    # will be executed by post_save signal on ExportBatch
-    last_export_batch = ExportBatch.objects.exclude(id=export_batch.id).order_by("-created_at").first()
+    last_export_batch = (
+        ExportBatch.objects.exclude(id=export_batch.id).order_by("-created_at").first()
+    )
     if last_export_batch:
         last_export_batch_created_at = last_export_batch.created_at
     else:
         last_export_batch_created_at = None
 
-    # compute the metrics for this export
+    # Compute the metrics for this export
     if last_export_batch_created_at:
         sentences_created_qs = Sentence.objects.filter(
             created_date__gt=last_export_batch_created_at,
@@ -497,18 +500,20 @@
         )
     else:
         connectivity_statements_created_qs = ConnectivityStatement.objects.all()
-    connectivity_statements_created_qs.exclude(state=CSState.DRAFT)  # skip draft statements
+    connectivity_statements_created_qs = connectivity_statements_created_qs.exclude(
+        state=CSState.DRAFT
+    )  # skip draft statements
     export_batch.connectivity_statements_created = connectivity_statements_created_qs.count()
-    # export_batch.save()
-
-    # compute the state metrics for this export
-    connectivity_statement_metrics = list(ConnectivityStatement.objects.values("state").annotate(count=Count("state")))
+    # Compute the state metrics for this export
+    connectivity_statement_metrics = list(
+        ConnectivityStatement.objects.values("state").annotate(count=Count("state"))
+    )
     for state in CSState:
-        try:
-            metric = [x for x in connectivity_statement_metrics if x.get("state") == state][0]
-        except IndexError:
-            metric = {"state": state.value, "count": 0}
+        metric = next(
+            (x for x in connectivity_statement_metrics if x.get("state") == state),
+            {"state": state.value, "count": 0},
+        )
         ExportMetrics.objects.create(
             export_batch=export_batch,
@@ -517,23 +522,23 @@
             entity=MetricEntity.CONNECTIVITY_STATEMENT,
             state=CSState(metric["state"]),
             count=metric["count"],
         )
     sentence_metrics = list(Sentence.objects.values("state").annotate(count=Count("state")))
     for state in SentenceState:
-        try:
-            metric = [x for x in sentence_metrics if x.get("state") == state][0]
-        except IndexError:
-            metric = {"state": state.value, "count": 0}
+        metric = next(
+            (x for x in sentence_metrics if x.get("state") == state),
+            {"state": state.value, "count": 0},
+        )
         ExportMetrics.objects.create(
             export_batch=export_batch,
             entity=MetricEntity.SENTENCE,
             state=SentenceState(metric["state"]),
             count=metric["count"],
         )
-    # ExportMetrics
     return export_batch
 
 
 def do_transition_to_exported(export_batch: ExportBatch, user: User):
     system_user = User.objects.get(username="system")
-    for connectivity_statement in export_batch.connectivity_statements.all():
+    connectivity_statements = export_batch.connectivity_statements.all()
+    for connectivity_statement in connectivity_statements:
         available_transitions = [
             available_state.target
             for available_state in connectivity_statement.get_available_user_state_transitions(
@@ -541,7 +546,6 @@
                 system_user
             )
         ]
         if CSState.EXPORTED in available_transitions:
-            # we need to update the state to exported when we are in the NP0 approved state and the system user has the permission to do so
             cs = ConnectivityStatementStateService(connectivity_statement).do_transition(
                 CSState.EXPORTED, system_user, user
             )
@@ -549,7 +553,6 @@
 
 def dump_export_batch(export_batch, folder_path: typing.Optional[str] = None) -> str:
-    # returns the path of the exported file
     if folder_path is None:
         folder_path = tempfile.gettempdir()
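Two small wins in the `compute_metrics` hunks above. First, the old `connectivity_statements_created_qs.exclude(...)` discarded its result, since QuerySets are immutable; the fix reassigns it, so draft statements are now actually skipped. Second, the `try/[0]/except IndexError` lookup collapses into `next()` with a default, which is runnable standalone:

```python
# next() returns the first match from the generator, or the default dict
# when no metric row exists for the given state.
sentence_metrics = [{"state": "open", "count": 3}]

state = "draft"
metric = next(
    (x for x in sentence_metrics if x.get("state") == state),
    {"state": state, "count": 0},
)
print(metric)  # {'state': 'draft', 'count': 0}
```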
@@ -560,36 +563,65 @@
     csv_attributes_mapping = generate_csv_attributes_mapping()
 
+    # Prefetch related data with filters
+    notes_prefetch = Prefetch(
+        "notes",
+        queryset=Note.objects.filter(type__in=[NoteType.PLAIN, NoteType.DIFFERENT]),
+        to_attr="prefetched_notes",
+    )
+    sentence_notes_prefetch = Prefetch(
+        "sentence__notes",
+        queryset=Note.objects.all(),
+        to_attr="prefetched_sentence_notes",
+    )
+    tags_prefetch = Prefetch(
+        "tags", queryset=Tag.objects.all(), to_attr="prefetched_tags"
+    )
+
+    connectivity_statements = export_batch.connectivity_statements.select_related(
+        "sentence", "sex", "functional_circuit_role", "projection_phenotype"
+    ).prefetch_related(
+        "origins",
+        notes_prefetch,
+        tags_prefetch,
+        "species",
+        "forward_connection",
+        "provenance_set",
+        sentence_notes_prefetch,
+        "via_set__anatomical_entities",
+        "via_set__from_entities",
+        "destinations__anatomical_entities",
+        "destinations__from_entities",
+    )
+
     with open(filepath, "w", newline="") as csvfile:
         writer = csv.writer(csvfile)
-        # Write header row
         headers = csv_attributes_mapping.keys()
         writer.writerow(headers)
-        # Write data rows
-        for obj in export_batch.connectivity_statements.all():
+        for cs in connectivity_statements:
             try:
-                rows = get_rows(obj)
+                rows = get_rows(cs)
             except UnexportableConnectivityStatement as e:
                 logging.warning(
-                    f"Connectivity Statement with id {obj.id} skipped due to {e}"
+                    f"Connectivity Statement with id {cs.id} skipped due to {e}"
                 )
                 continue
+
             for row in rows:
-                row_content = []
-                for key in csv_attributes_mapping:
-                    row_content.append(csv_attributes_mapping[key](obj, row))
+                row_content = [func(cs, row) for func in csv_attributes_mapping.values()]
                 writer.writerow(row_content)
+
     return filepath
 
 
 def export_connectivity_statements(
-        qs: QuerySet, user: User, folder_path: typing.Optional[str]
+    qs: QuerySet, user: User, folder_path: typing.Optional[str]
 ) -> typing.Tuple[str, ExportBatch]:
     with transaction.atomic():
-        # make sure create_export_batch and do_transition_to_exported are in one database transaction
+        # Ensure create_export_batch and do_transition_to_exported are in one database transaction
        export_batch = create_export_batch(qs, user)
         do_transition_to_exported(export_batch, user)
         export_file = dump_export_batch(export_batch, folder_path)
-    return export_file, export_batch
+    return export_file, export_batch
\ No newline at end of file
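With the `select_related`/`prefetch_related` block above, most of the per-statement queries in `dump_export_batch` disappear: the notes, tags, species, provenances, vias, and destinations for the whole batch are loaded up front. A hedged sketch of how this could be verified in a test (assumes a populated `export_batch` fixture; `CaptureQueriesContext` is Django's query-capturing helper from `django.test.utils`):

```python
from django.db import connection
from django.test.utils import CaptureQueriesContext

from composer.services.export_services import dump_export_batch

with CaptureQueriesContext(connection) as ctx:
    dump_export_batch(export_batch)  # export_batch: hypothetical test fixture

# Expect a large drop versus the old implementation. The total is not
# perfectly flat: get_rows still orders and counts via_set per statement.
print(len(ctx.captured_queries))
```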
diff --git a/backend/composer/templates/admin/index.html b/backend/composer/templates/admin/index.html
index 6b3952b3..73399f4f 100644
--- a/backend/composer/templates/admin/index.html
+++ b/backend/composer/templates/admin/index.html
@@ -30,12 +30,16 @@
     Export statistics
 
-    Create new export
-
+
+    Create new export
+
+