diff --git a/oc_meta/run/fixer/prov/fix.py b/oc_meta/run/fixer/prov/fix.py
index 82a71dc..62cb7a7 100644
--- a/oc_meta/run/fixer/prov/fix.py
+++ b/oc_meta/run/fixer/prov/fix.py
@@ -3,11 +3,11 @@
 import re
 import zipfile
 from collections import defaultdict
-from datetime import UTC, datetime
+from datetime import datetime, timezone
 from multiprocessing import Pool, cpu_count
 from zoneinfo import ZoneInfo
 from typing import Dict, List, Optional, Set, Tuple
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from rdflib import ConjunctiveGraph, Literal, Namespace, URIRef
 from rdflib.namespace import XSD
@@ -22,14 +22,23 @@ class SnapshotInfo:
     generation_times: List[Literal]
     invalidation_times: List[Literal]
 
+@dataclass
+class ModificationTracker:
+    modifications: Dict[str, Dict[str, List[str]]] = field(default_factory=lambda: {})
+
+    def add_modification(self, entity_uri: str, mod_type: str, message: str) -> None:
+        if entity_uri not in self.modifications:
+            self.modifications[entity_uri] = defaultdict(list)
+        self.modifications[entity_uri][mod_type].append(message)
+
 class ProvenanceProcessor:
     def __init__(self):
         self._snapshot_number_pattern = re.compile(r'/prov/se/(\d+)$')
         self._default_time = Literal(
-            datetime(2022, 12, 20, tzinfo=UTC).isoformat(),
+            datetime(2022, 12, 20, tzinfo=timezone.utc).isoformat(),
             datatype=XSD.dateTime
         )
-        self.modifications = defaultdict(lambda: defaultdict(list))
+        self.tracker = ModificationTracker()
 
     def _extract_snapshot_number(self, snapshot_uri: str) -> int:
         """Extract the snapshot number from its URI using pre-compiled regex."""
@@ -48,7 +57,7 @@ def _convert_to_utc(self, timestamp_str: str) -> datetime:
 
         if dt.tzinfo is None:
             dt = dt.replace(tzinfo=ZoneInfo("Europe/Rome"))
-        return dt.astimezone(UTC)
+        return dt.astimezone(timezone.utc)
 
     def _normalize_timestamp(self, literal: Literal) -> Tuple[Literal, bool]:
         """Normalize a timestamp literal to UTC timezone."""
@@ -86,49 +95,47 @@ def _remove_multiple_timestamps(self, context: ConjunctiveGraph,
                                     snapshot_uri: URIRef, predicate: URIRef,
                                     timestamps: List[Literal]) -> None:
-        """Rimuove tutti i timestamp esistenti per un dato predicato."""
+        """Remove all timestamps for a given predicate."""
         for ts in timestamps:
             context.remove((snapshot_uri, predicate, ts))
-            self.modifications[str(snapshot_uri)][f"Removed {predicate.split('#')[-1]}"].append(
-                f"{str(snapshot_uri)}: {str(ts)}")
+            self.tracker.add_modification(
+                str(snapshot_uri),
+                f"Removed {predicate.split('#')[-1]}",
+                f"{str(snapshot_uri)}: {str(ts)}"
+            )
 
-    def process_file(self, prov_file_path: str) -> Optional[Tuple[str, Dict]]:
+    @staticmethod
+    def process_file(prov_file_path: str) -> Optional[Tuple[str, Dict]]:
         """Process a single provenance file with optimized operations."""
+        processor = ProvenanceProcessor()  # Create new instance for each file
         try:
             with zipfile.ZipFile(prov_file_path, 'r') as zip_ref:
                 g = ConjunctiveGraph()
-                # Parse all files in a single operation
                 for filename in zip_ref.namelist():
                     with zip_ref.open(filename) as file:
                         g.parse(file, format='json-ld')
 
                 modified = False
-                # Process each context
                 for context in g.contexts():
                     context_uri = str(context.identifier)
                     if not context_uri.endswith('/prov/'):
                         continue
 
-                    entity_uri = URIRef(self._get_entity_from_prov_graph(context_uri))
+                    entity_uri = URIRef(processor._get_entity_from_prov_graph(context_uri))
+                    snapshots = processor._collect_snapshot_info(context)
 
-                    # Collect all snapshot info in a single pass
-                    snapshots = self._collect_snapshot_info(context)
-                    if not snapshots:
-                        continue
-
-                    # Batch process modifications
-                    modified |= self._process_snapshots(context, entity_uri, snapshots)
+                    if snapshots:
+                        modified |= processor._process_snapshots(context, entity_uri, snapshots)
 
                 if modified:
-                    # Save modifications in a single operation
                     with zipfile.ZipFile(prov_file_path, 'w', zipfile.ZIP_DEFLATED, allowZip64=True) as zip_out:
                         jsonld_data = g.serialize(format='json-ld', encoding='utf-8', ensure_ascii=False, indent=None)
                         zip_out.writestr('se.json', jsonld_data)
 
-                return str(prov_file_path), dict(self.modifications)
+                return str(prov_file_path), processor.tracker.modifications
 
         except Exception as e:
             print(f"Error processing {prov_file_path}: {e}")
@@ -139,13 +146,16 @@ def _process_snapshots(self, context: ConjunctiveGraph, entity_uri: URIRef,
                            snapshots: List[SnapshotInfo]) -> bool:
         """Process all snapshots in batch operations."""
         modified = False
-        mods = self.modifications[str(entity_uri)]
 
         # Process specializationOf relationships
         for snapshot in snapshots:
             if not any(context.objects(snapshot.uri, PROV.specializationOf)):
                 context.add((snapshot.uri, PROV.specializationOf, entity_uri))
-                mods["Added specializationOf"].append(str(snapshot.uri))
+                self.tracker.add_modification(
+                    str(entity_uri),
+                    "Added specializationOf",
+                    str(snapshot.uri)
+                )
                 modified = True
 
         # Process wasDerivedFrom relationships
@@ -155,8 +165,11 @@ def _process_snapshots(self, context: ConjunctiveGraph, entity_uri: URIRef,
 
             if not any(context.objects(curr_snapshot.uri, PROV.wasDerivedFrom)):
                 context.add((curr_snapshot.uri, PROV.wasDerivedFrom, prev_snapshot.uri))
-                mods["Added wasDerivedFrom"].append(
-                    f"{str(curr_snapshot.uri)} → {str(prev_snapshot.uri)}")
+                self.tracker.add_modification(
+                    str(entity_uri),
+                    "Added wasDerivedFrom",
+                    f"{str(curr_snapshot.uri)} → {str(prev_snapshot.uri)}"
+                )
                 modified = True
 
         # Process timestamps
@@ -185,11 +198,9 @@ def _handle_generation_time(self, context: ConjunctiveGraph,
         modified = False
         snapshot = snapshots[index]
 
-        # Se ci sono timestamp multipli o nessun timestamp, li gestiamo
         if len(snapshot.generation_times) != 1:
             new_time = None
 
-            # Rimuovi tutti i timestamp esistenti se ce ne sono
             if snapshot.generation_times:
                 self._remove_multiple_timestamps(
                     context, snapshot.uri, PROV.generatedAtTime, snapshot.generation_times)
@@ -201,7 +212,6 @@ def _handle_generation_time(self, context: ConjunctiveGraph,
                     new_time = prev_snapshot.invalidation_times[0]
                 elif (prev_snapshot.generation_times and snapshot.invalidation_times
                       and len(snapshot.invalidation_times) == 1):
-                    # Calculate intermediate time
                     prev_time = self._convert_to_utc(prev_snapshot.generation_times[0])
                     curr_time = self._convert_to_utc(snapshot.invalidation_times[0])
                     middle_time = prev_time + (curr_time - prev_time) / 2
@@ -211,47 +221,49 @@ def _handle_generation_time(self, context: ConjunctiveGraph,
 
             if new_time:
                 context.add((snapshot.uri, PROV.generatedAtTime, new_time))
-                self.modifications[str(snapshot.uri)]["Added generatedAtTime"].append(
-                    f"{str(snapshot.uri)}: {str(new_time)}")
+                self.tracker.add_modification(
+                    str(snapshot.uri),
+                    "Added generatedAtTime",
+                    f"{str(snapshot.uri)}: {str(new_time)}"
+                )
                 modified = True
 
         return modified
 
     def _handle_invalidation_time(self, context: ConjunctiveGraph,
-                                 snapshots: List[SnapshotInfo], index: int) -> bool:
-        """Handle invalidation time for a snapshot."""
-        modified = False
-        snapshot = snapshots[index]
-        next_snapshot = snapshots[index + 1]
+                                  snapshots: List[SnapshotInfo], index: int) -> bool:
+        """Handle invalidation time for a snapshot."""
+        modified = False
+        snapshot = snapshots[index]
+        next_snapshot = snapshots[index + 1]
 
-        # Gestisci timestamp multipli o mancanti
-        if len(snapshot.invalidation_times) != 1:
-            # Rimuovi tutti i timestamp esistenti se ce ne sono
-            if snapshot.invalidation_times:
-                self._remove_multiple_timestamps(
-                    context, snapshot.uri, PROV.invalidatedAtTime, snapshot.invalidation_times)
-                modified = True
+        if len(snapshot.invalidation_times) != 1:
+            if snapshot.invalidation_times:
+                self._remove_multiple_timestamps(
+                    context, snapshot.uri, PROV.invalidatedAtTime, snapshot.invalidation_times)
+                modified = True
 
-            new_time = None
-            if next_snapshot.generation_times:
-                if len(next_snapshot.generation_times) == 1:
-                    # Caso semplice: usa l'unico generation time disponibile
-                    new_time = next_snapshot.generation_times[0]
-                else:
-                    # Caso con timestamp multipli: usa il più vecchio come punto di invalidazione
-                    earliest_time = min(
-                        self._convert_to_utc(ts)
-                        for ts in next_snapshot.generation_times
-                    )
-                    new_time = Literal(earliest_time.isoformat(), datatype=XSD.dateTime)
+            new_time = None
+            if next_snapshot.generation_times:
+                if len(next_snapshot.generation_times) == 1:
+                    new_time = next_snapshot.generation_times[0]
+                else:
+                    earliest_time = min(
+                        self._convert_to_utc(ts)
+                        for ts in next_snapshot.generation_times
+                    )
+                    new_time = Literal(earliest_time.isoformat(), datatype=XSD.dateTime)
 
-            if new_time:
-                context.add((snapshot.uri, PROV.invalidatedAtTime, new_time))
-                self.modifications[str(snapshot.uri)]["Added invalidatedAtTime"].append(
-                    f"{str(snapshot.uri)}: {str(new_time)}")
-                modified = True
-
-        return modified
+            if new_time:
+                context.add((snapshot.uri, PROV.invalidatedAtTime, new_time))
+                self.tracker.add_modification(
+                    str(snapshot.uri),
+                    "Added invalidatedAtTime",
+                    f"{str(snapshot.uri)}: {str(new_time)}"
+                )
+                modified = True
+
+        return modified
 
 def main():
     parser = argparse.ArgumentParser(description="Fix provenance files in parallel")
@@ -267,11 +279,9 @@ def main():
         if file.endswith('se.zip')
     ]
 
-    processor = ProvenanceProcessor()
-
     with Pool(processes=args.processes) as pool:
         results = list(tqdm(
-            pool.imap_unordered(processor.process_file, prov_files),
+            pool.imap_unordered(ProvenanceProcessor.process_file, prov_files),
             total=len(prov_files),
             desc="Fixing provenance files"
        ))