
Commit

UTC -> timezone.utc, picklable modifications in prov fix
arcangelo7 committed Oct 25, 2024
1 parent 1de4031 commit 58ad888
Showing 1 changed file with 74 additions and 64 deletions.
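
Note on the two fixes (a minimal sketch, hypothetical repro, not part of the commit): datetime.UTC only exists on Python 3.11+, so spelling it timezone.utc keeps the script importable on older interpreters. The picklability issue comes from multiprocessing: the pool pickles whatever it is handed, a bound method drags its instance along, and the old processor held a defaultdict(lambda: defaultdict(list)) whose lambda cannot be pickled.

    import pickle
    from collections import defaultdict
    from dataclasses import dataclass, field
    from datetime import datetime, timezone
    from typing import Dict, List

    # Portable UTC: datetime.UTC (Python 3.11+) is an alias of timezone.utc.
    dt = datetime(2022, 12, 20, tzinfo=timezone.utc)

    # Old state: the lambda default_factory cannot be pickled, so any object
    # holding this dict (e.g. a processor sent to a worker pool) fails to serialize.
    old_modifications = defaultdict(lambda: defaultdict(list))
    try:
        pickle.dumps(old_modifications)
    except (pickle.PicklingError, AttributeError) as exc:
        print(f"not picklable: {exc}")

    # New state: pickle serializes the instance's dict, not the class, so the
    # lambda in default_factory is harmless; the stored values are plain
    # defaultdict(list), which pickles fine.
    @dataclass
    class ModificationTracker:
        modifications: Dict[str, Dict[str, List[str]]] = field(default_factory=lambda: {})

        def add_modification(self, entity_uri: str, mod_type: str, message: str) -> None:
            if entity_uri not in self.modifications:
                self.modifications[entity_uri] = defaultdict(list)
            self.modifications[entity_uri][mod_type].append(message)

    tracker = ModificationTracker()
    tracker.add_modification("https://example.org/br/1/prov/se/1", "Added generatedAtTime", "2022-12-20")
    assert pickle.loads(pickle.dumps(tracker)).modifications == tracker.modifications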
138 changes: 74 additions & 64 deletions oc_meta/run/fixer/prov/fix.py
@@ -3,11 +3,11 @@
 import re
 import zipfile
 from collections import defaultdict
-from datetime import UTC, datetime
+from datetime import datetime, timezone
 from multiprocessing import Pool, cpu_count
 from zoneinfo import ZoneInfo
 from typing import Dict, List, Optional, Set, Tuple
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 from rdflib import ConjunctiveGraph, Literal, Namespace, URIRef
 from rdflib.namespace import XSD
@@ -22,14 +22,23 @@ class SnapshotInfo:
     generation_times: List[Literal]
     invalidation_times: List[Literal]
 
+@dataclass
+class ModificationTracker:
+    modifications: Dict[str, Dict[str, List[str]]] = field(default_factory=lambda: {})
+
+    def add_modification(self, entity_uri: str, mod_type: str, message: str) -> None:
+        if entity_uri not in self.modifications:
+            self.modifications[entity_uri] = defaultdict(list)
+        self.modifications[entity_uri][mod_type].append(message)
+
 class ProvenanceProcessor:
     def __init__(self):
         self._snapshot_number_pattern = re.compile(r'/prov/se/(\d+)$')
         self._default_time = Literal(
-            datetime(2022, 12, 20, tzinfo=UTC).isoformat(),
+            datetime(2022, 12, 20, tzinfo=timezone.utc).isoformat(),
             datatype=XSD.dateTime
         )
-        self.modifications = defaultdict(lambda: defaultdict(list))
+        self.tracker = ModificationTracker()
 
     def _extract_snapshot_number(self, snapshot_uri: str) -> int:
         """Extract the snapshot number from its URI using pre-compiled regex."""
@@ -48,7 +57,7 @@ def _convert_to_utc(self, timestamp_str: str) -> datetime:
         if dt.tzinfo is None:
             dt = dt.replace(tzinfo=ZoneInfo("Europe/Rome"))
 
-        return dt.astimezone(UTC)
+        return dt.astimezone(timezone.utc)
 
     def _normalize_timestamp(self, literal: Literal) -> Tuple[Literal, bool]:
         """Normalize a timestamp literal to UTC timezone."""
@@ -86,49 +95,47 @@ def _remove_multiple_timestamps(self, context: ConjunctiveGraph,
                                     snapshot_uri: URIRef,
                                     predicate: URIRef,
                                     timestamps: List[Literal]) -> None:
-        """Rimuove tutti i timestamp esistenti per un dato predicato."""
+        """Remove all timestamps for a given predicate."""
         for ts in timestamps:
             context.remove((snapshot_uri, predicate, ts))
-            self.modifications[str(snapshot_uri)][f"Removed {predicate.split('#')[-1]}"].append(
-                f"{str(snapshot_uri)}: {str(ts)}")
+            self.tracker.add_modification(
+                str(snapshot_uri),
+                f"Removed {predicate.split('#')[-1]}",
+                f"{str(snapshot_uri)}: {str(ts)}"
+            )
 
-    def process_file(self, prov_file_path: str) -> Optional[Tuple[str, Dict]]:
+    @staticmethod
+    def process_file(prov_file_path: str) -> Optional[Tuple[str, Dict]]:
         """Process a single provenance file with optimized operations."""
+        processor = ProvenanceProcessor() # Create new instance for each file
         try:
             with zipfile.ZipFile(prov_file_path, 'r') as zip_ref:
                 g = ConjunctiveGraph()
 
-                # Parse all files in a single operation
                 for filename in zip_ref.namelist():
                     with zip_ref.open(filename) as file:
                         g.parse(file, format='json-ld')
 
                 modified = False
 
-                # Process each context
                 for context in g.contexts():
                     context_uri = str(context.identifier)
                     if not context_uri.endswith('/prov/'):
                         continue
 
-                    entity_uri = URIRef(self._get_entity_from_prov_graph(context_uri))
+                    entity_uri = URIRef(processor._get_entity_from_prov_graph(context_uri))
+                    snapshots = processor._collect_snapshot_info(context)
 
-                    # Collect all snapshot info in a single pass
-                    snapshots = self._collect_snapshot_info(context)
-                    if not snapshots:
-                        continue
-
-                    # Batch process modifications
-                    modified |= self._process_snapshots(context, entity_uri, snapshots)
+                    if snapshots:
+                        modified |= processor._process_snapshots(context, entity_uri, snapshots)
 
                 if modified:
-                    # Save modifications in a single operation
                     with zipfile.ZipFile(prov_file_path, 'w', zipfile.ZIP_DEFLATED, allowZip64=True) as zip_out:
                         jsonld_data = g.serialize(format='json-ld', encoding='utf-8',
                                                   ensure_ascii=False, indent=None)
                         zip_out.writestr('se.json', jsonld_data)
 
-                return str(prov_file_path), dict(self.modifications)
+                return str(prov_file_path), processor.tracker.modifications
 
         except Exception as e:
             print(f"Error processing {prov_file_path}: {e}")
@@ -139,13 +146,16 @@ def _process_snapshots(self, context: ConjunctiveGraph, entity_uri: URIRef,
                            snapshots: List[SnapshotInfo]) -> bool:
         """Process all snapshots in batch operations."""
         modified = False
-        mods = self.modifications[str(entity_uri)]
 
         # Process specializationOf relationships
         for snapshot in snapshots:
             if not any(context.objects(snapshot.uri, PROV.specializationOf)):
                 context.add((snapshot.uri, PROV.specializationOf, entity_uri))
-                mods["Added specializationOf"].append(str(snapshot.uri))
+                self.tracker.add_modification(
+                    str(entity_uri),
+                    "Added specializationOf",
+                    str(snapshot.uri)
+                )
                 modified = True
 
         # Process wasDerivedFrom relationships
@@ -155,8 +165,11 @@ def _process_snapshots(self, context: ConjunctiveGraph, entity_uri: URIRef,
 
             if not any(context.objects(curr_snapshot.uri, PROV.wasDerivedFrom)):
                 context.add((curr_snapshot.uri, PROV.wasDerivedFrom, prev_snapshot.uri))
-                mods["Added wasDerivedFrom"].append(
-                    f"{str(curr_snapshot.uri)}{str(prev_snapshot.uri)}")
+                self.tracker.add_modification(
+                    str(entity_uri),
+                    "Added wasDerivedFrom",
+                    f"{str(curr_snapshot.uri)}{str(prev_snapshot.uri)}"
+                )
                 modified = True
 
         # Process timestamps
@@ -185,11 +198,9 @@ def _handle_generation_time(self, context: ConjunctiveGraph,
         modified = False
         snapshot = snapshots[index]
 
-        # Se ci sono timestamp multipli o nessun timestamp, li gestiamo
         if len(snapshot.generation_times) != 1:
             new_time = None
 
-            # Rimuovi tutti i timestamp esistenti se ce ne sono
             if snapshot.generation_times:
                 self._remove_multiple_timestamps(
                     context, snapshot.uri, PROV.generatedAtTime, snapshot.generation_times)
@@ -201,7 +212,6 @@
                     new_time = prev_snapshot.invalidation_times[0]
                 elif (prev_snapshot.generation_times and
                       snapshot.invalidation_times and len(snapshot.invalidation_times) == 1):
-                    # Calculate intermediate time
                     prev_time = self._convert_to_utc(prev_snapshot.generation_times[0])
                     curr_time = self._convert_to_utc(snapshot.invalidation_times[0])
                     middle_time = prev_time + (curr_time - prev_time) / 2
@@ -211,47 +221,49 @@
 
             if new_time:
                 context.add((snapshot.uri, PROV.generatedAtTime, new_time))
-                self.modifications[str(snapshot.uri)]["Added generatedAtTime"].append(
-                    f"{str(snapshot.uri)}: {str(new_time)}")
+                self.tracker.add_modification(
+                    str(snapshot.uri),
+                    "Added generatedAtTime",
+                    f"{str(snapshot.uri)}: {str(new_time)}"
+                )
                 modified = True
 
         return modified
 
     def _handle_invalidation_time(self, context: ConjunctiveGraph,
-                                      snapshots: List[SnapshotInfo], index: int) -> bool:
-            """Handle invalidation time for a snapshot."""
-            modified = False
-            snapshot = snapshots[index]
-            next_snapshot = snapshots[index + 1]
+                                  snapshots: List[SnapshotInfo], index: int) -> bool:
+        """Handle invalidation time for a snapshot."""
+        modified = False
+        snapshot = snapshots[index]
+        next_snapshot = snapshots[index + 1]
 
-            # Gestisci timestamp multipli o mancanti
-            if len(snapshot.invalidation_times) != 1:
-                # Rimuovi tutti i timestamp esistenti se ce ne sono
-                if snapshot.invalidation_times:
-                    self._remove_multiple_timestamps(
-                        context, snapshot.uri, PROV.invalidatedAtTime, snapshot.invalidation_times)
-                    modified = True
+        if len(snapshot.invalidation_times) != 1:
+            if snapshot.invalidation_times:
+                self._remove_multiple_timestamps(
+                    context, snapshot.uri, PROV.invalidatedAtTime, snapshot.invalidation_times)
+                modified = True
 
-                new_time = None
-                if next_snapshot.generation_times:
-                    if len(next_snapshot.generation_times) == 1:
-                        # Caso semplice: usa l'unico generation time disponibile
-                        new_time = next_snapshot.generation_times[0]
-                    else:
-                        # Caso con timestamp multipli: usa il più vecchio come punto di invalidazione
-                        earliest_time = min(
-                            self._convert_to_utc(ts)
-                            for ts in next_snapshot.generation_times
-                        )
-                        new_time = Literal(earliest_time.isoformat(), datatype=XSD.dateTime)
+            new_time = None
+            if next_snapshot.generation_times:
+                if len(next_snapshot.generation_times) == 1:
+                    new_time = next_snapshot.generation_times[0]
+                else:
+                    earliest_time = min(
+                        self._convert_to_utc(ts)
+                        for ts in next_snapshot.generation_times
+                    )
+                    new_time = Literal(earliest_time.isoformat(), datatype=XSD.dateTime)
 
-                if new_time:
-                    context.add((snapshot.uri, PROV.invalidatedAtTime, new_time))
-                    self.modifications[str(snapshot.uri)]["Added invalidatedAtTime"].append(
-                        f"{str(snapshot.uri)}: {str(new_time)}")
-                    modified = True
-
-            return modified
+            if new_time:
+                context.add((snapshot.uri, PROV.invalidatedAtTime, new_time))
+                self.tracker.add_modification(
+                    str(snapshot.uri),
+                    "Added invalidatedAtTime",
+                    f"{str(snapshot.uri)}: {str(new_time)}"
+                )
+                modified = True
+
+        return modified
 
 def main():
     parser = argparse.ArgumentParser(description="Fix provenance files in parallel")
@@ -267,11 +279,9 @@ def main():
         if file.endswith('se.zip')
     ]
 
-    processor = ProvenanceProcessor()
-
    with Pool(processes=args.processes) as pool:
         results = list(tqdm(
-            pool.imap_unordered(processor.process_file, prov_files),
+            pool.imap_unordered(ProvenanceProcessor.process_file, prov_files),
             total=len(prov_files),
             desc="Fixing provenance files"
         ))
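
The @staticmethod change is the other half of the picklability fix: pool.imap_unordered now pickles only a plain function reference plus each file path, and each call builds its own ProvenanceProcessor, so no mutable state crosses process boundaries. A minimal sketch of the pattern (illustrative names, not from the repo):

    from multiprocessing import Pool

    class FileWorker:
        @staticmethod
        def process(path: str) -> str:
            # Fresh per-call state; nothing is shared between worker processes.
            worker = FileWorker()
            return f"processed {path} with {type(worker).__name__}"

    if __name__ == "__main__":
        with Pool(processes=2) as pool:
            # Only the function reference and the items are pickled, never an instance.
            for result in pool.imap_unordered(FileWorker.process, ["a_se.zip", "b_se.zip"]):
                print(result)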
