From bd451753724f368a922f96c3c8990b0caffb7f0c Mon Sep 17 00:00:00 2001
From: mainlyIt
Date: Wed, 5 Feb 2025 17:47:18 +0100
Subject: [PATCH] import & export obs (#347)

---
 vespadb/observations/views.py | 197 +++++++++++++++++++++++++++++-----
 1 file changed, 168 insertions(+), 29 deletions(-)

diff --git a/vespadb/observations/views.py b/vespadb/observations/views.py
index 111a110..151804e 100644
--- a/vespadb/observations/views.py
+++ b/vespadb/observations/views.py
@@ -16,6 +16,8 @@ from django.db.models import Model
 from csv import writer as _writer
 from django.db.models.query import QuerySet
+from django.contrib.gis.geos import Point
+from dateutil import parser
 from django.views.decorators.csrf import csrf_exempt
 from django.contrib.gis.db.models.functions import Transform
@@ -55,9 +57,10 @@ from vespadb.observations.helpers import parse_and_convert_to_utc
 from vespadb.observations.models import Municipality, Observation, Province, Export
 from vespadb.observations.models import Export
+from vespadb.observations.tasks.export_utils import generate_rows
 from vespadb.observations.tasks.generate_export import generate_export
 from vespadb.observations.serializers import ObservationSerializer, MunicipalitySerializer, ProvinceSerializer
-
+from vespadb.observations.utils import check_if_point_in_anb_area, get_municipality_from_coordinates
 from django.utils.decorators import method_decorator
 from django_ratelimit.decorators import ratelimit
 from rest_framework.decorators import action
@@ -81,12 +84,6 @@ def write(self, value: Any) -> Any:
 GEOJSON_REDIS_CACHE_EXPIRATION = 900  # 15 minutes
 GET_REDIS_CACHE_EXPIRATION = 86400  # 1 day
 BATCH_SIZE = 150
-CSV_HEADERS = [
-    "id", "created_datetime", "modified_datetime", "latitude", "longitude", "source", "source_id",
-    "nest_height", "nest_size", "nest_location", "nest_type", "observation_datetime",
-    "province", "eradication_date", "municipality", "images", "anb_domain",
-    "notes", "eradication_result", "wn_id", "wn_validation_status", "nest_status"
-]
 
 class ObservationsViewSet(ModelViewSet):  # noqa: PLR0904
     """ViewSet for the Observation model."""
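Reviewer note: the two new imports above (Point and dateutil's parser) feed the coordinate and datetime handling introduced in process_data below. Reduced to its core, the coordinate pattern looks like this (a minimal sketch; build_location is an illustrative name, not part of the patch, and it assumes a configured GeoDjango/GDAL environment):

    from django.contrib.gis.geos import Point

    def build_location(longitude: str, latitude: str) -> Point:
        # float() raises ValueError on malformed input; the patch catches this
        # per row and records an import error instead of aborting the batch.
        long = float(longitude)
        lat = float(latitude)
        # GEOS points take (x, y) = (longitude, latitude); SRID 4326 is WGS84.
        return Point(long, lat, srid=4326)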
@@ -580,21 +577,144 @@ def validate_location(self, location: str) -> GEOSGeometry:
 
     def process_data(self, data: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
         """Process and validate the incoming data."""
+        logger.info("Starting to process import data")
+
         valid_observations = []
         errors = []
+        current_time = now()
+
         for data_item in data:
             try:
-                cleaned_item = self.clean_data(data_item)
-                serializer = ObservationSerializer(data=cleaned_item)
-                if serializer.is_valid():
-                    valid_observations.append(serializer.validated_data)
-                else:
-                    errors.append(serializer.errors)
+                logger.info(f"Processing data item: {data_item}")
+
+                # Issue #294 - Only allow specific fields in import
+                allowed_fields = {
+                    'id', 'source_id', 'observation_datetime', 'eradication_problems',
+                    'source', 'eradication_notes', 'images', 'created_datetime',
+                    'longitude', 'latitude', 'eradication_persons', 'nest_size',
+                    'visible', 'nest_location', 'eradication_date', 'eradication_product',
+                    'nest_type', 'eradicator_name', 'eradication_method',
+                    'eradication_aftercare', 'public_domain', 'eradication_duration',
+                    'nest_height', 'eradication_result', 'notes', 'admin_notes'
+                }
+
+                # Filter out non-allowed fields
+                data_item = {k: v for k, v in data_item.items() if k in allowed_fields}
+
+                # Issue #297 - Handle record updates vs inserts
+                observation_id = data_item.pop('id', None)  # Remove id from data_item if it exists
+
+                if observation_id is not None:  # Update existing record
+                    try:
+                        existing_obj = Observation.objects.get(id=observation_id)
+                        logger.info(f"Updating existing observation with id {observation_id}")
+
+                        # Don't modify created_by and created_datetime for existing records
+                        data_item.pop('created_by', None)
+                        data_item.pop('created_datetime', None)
+
+                        # Set modified_by to import user and modified_datetime to current time
+                        data_item['modified_by'] = self.request.user
+                        data_item['modified_datetime'] = current_time
+
+                        # Issue #290 - Auto-determine province, municipality and anb
+                        # Handle coordinates for updates; use *_id keys here, since the
+                        # values are applied via setattr() and a ForeignKey attribute
+                        # rejects a raw pk
+                        if 'longitude' in data_item and 'latitude' in data_item:
+                            try:
+                                long = float(data_item.pop('longitude'))
+                                lat = float(data_item.pop('latitude'))
+                                data_item['location'] = Point(long, lat, srid=4326)
+                                logger.info(f"Created point from coordinates: {long}, {lat}")
+
+                                # Determine municipality, province and anb status
+                                municipality = get_municipality_from_coordinates(long, lat)
+                                if municipality:
+                                    data_item['municipality_id'] = municipality.id
+                                    if municipality.province:
+                                        data_item['province_id'] = municipality.province.id
+                                data_item['anb'] = check_if_point_in_anb_area(long, lat)
+
+                                logger.info(f"Municipality ID: {data_item.get('municipality_id')}, Province ID: {data_item.get('province_id')}, ANB: {data_item['anb']}")
+                            except (ValueError, TypeError) as e:
+                                logger.error(f"Invalid coordinates: {e}")
+                                errors.append({"error": f"Invalid coordinates: {str(e)}"})
+                                continue
+
+                        # Issue #292 - Fix timezone handling for eradication_date
+                        if 'eradication_date' in data_item:
+                            date_str = data_item['eradication_date']
+                            if isinstance(date_str, str):
+                                try:
+                                    data_item['eradication_date'] = datetime.datetime.strptime(date_str, '%Y-%m-%d').date()
+                                except ValueError:
+                                    errors.append({"error": f"Invalid date format for eradication_date: {date_str}"})
+                                    continue
+
+                        for key, value in data_item.items():
+                            setattr(existing_obj, key, value)
+                        existing_obj.save()
+                        valid_observations.append(existing_obj)
+                        continue
+                    except Observation.DoesNotExist:
+                        logger.error(f"Observation with id {observation_id} not found")
+                        errors.append({"error": f"Observation with id {observation_id} not found"})
+                        continue
+                else:  # New record
+                    # Set created_by to import user
+                    data_item['created_by'] = self.request.user
+
+                    # Set created_datetime to provided value or current time
+                    if 'created_datetime' not in data_item:
+                        data_item['created_datetime'] = current_time
+
+                    # Always set modified_by and modified_datetime for new records
+                    data_item['modified_by'] = self.request.user
+                    data_item['modified_datetime'] = current_time
+
+                    # Handle coordinates for new records
+                    if 'longitude' in data_item and 'latitude' in data_item:
+                        try:
+                            long = float(data_item.pop('longitude'))
+                            lat = float(data_item.pop('latitude'))
+                            data_item['location'] = Point(long, lat, srid=4326)
+                            logger.info(f"Created point from coordinates: {long}, {lat}")
+
+                            # Determine municipality, province and anb status
+                            municipality = get_municipality_from_coordinates(long, lat)
+                            if municipality:
+                                data_item['municipality'] = municipality.id
+                                if municipality.province:
+                                    data_item['province'] = municipality.province.id
+                            data_item['anb'] = check_if_point_in_anb_area(long, lat)
+
+                            logger.info(f"Municipality ID: {data_item.get('municipality')}, Province ID: {data_item.get('province')}, ANB: {data_item['anb']}")
+                        except (ValueError, TypeError) as e:
+                            logger.error(f"Invalid coordinates: {e}")
+                            errors.append({"error": f"Invalid coordinates: {str(e)}"})
+                            continue
+
+                    # Issue #292 - Fix timezone handling for eradication_date
+                    if 'eradication_date' in data_item:
+                        date_str = data_item['eradication_date']
+                        if isinstance(date_str, str):
+                            try:
+                                data_item['eradication_date'] = datetime.datetime.strptime(date_str, '%Y-%m-%d').date()
+                            except ValueError:
+                                errors.append({"error": f"Invalid date format for eradication_date: {date_str}"})
+                                continue
+
+                    cleaned_item = self.clean_data(data_item)
+                    serializer = ObservationSerializer(data=cleaned_item)
+                    if serializer.is_valid():
+                        valid_observations.append(serializer.validated_data)
+                    else:
+                        errors.append(serializer.errors)
             except Exception as e:
                 logger.exception(f"Error processing data item: {data_item} - {e}")
                 errors.append({"error": str(e)})
+
         return valid_observations, errors
-    
+
     def clean_data(self, data_dict: dict[str, Any]) -> dict[str, Any]:
         """Clean the incoming data and remove empty or None values."""
         logger.info("Original data item: %s", data_dict)
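Reviewer note: the eradication_date blocks above are the substance of the Issue #292 fix. Stripped of the import plumbing, the behaviour is simply this (plain-Python sketch; the sample values are illustrative):

    import datetime

    # A bare date string is parsed into a date, not a midnight datetime, so
    # later UTC conversion cannot shift it to the previous or next day.
    date_str = "2025-02-05"
    eradication_date = datetime.datetime.strptime(date_str, "%Y-%m-%d").date()
    assert eradication_date == datetime.date(2025, 2, 5)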
@@ -610,13 +730,15 @@ def clean_data(self, data_dict: dict[str, Any]) -> dict[str, Any]:
         ]
         for field in datetime_fields:
             if data_dict.get(field):
+                # Keep ISO format strings as-is
                 if isinstance(data_dict[field], str):
                     try:
-                        data_dict[field] = parse_and_convert_to_utc(data_dict[field]).isoformat()
+                        # Just validate the format but keep the original value
+                        parser.parse(data_dict[field])
                     except (ValueError, TypeError):
                         logger.exception(f"Invalid datetime format for {field}: {data_dict[field]}")
                         data_dict.pop(field, None)
                 elif isinstance(data_dict[field], datetime.datetime):
                     data_dict[field] = data_dict[field].isoformat()
                 else:
                     data_dict.pop(field, None)
@@ -627,22 +749,35 @@ def clean_data(self, data_dict: dict[str, Any]) -> dict[str, Any]:
             if not data_dict.get(field):
                 data_dict[field] = None
 
-        cleaned_data = {k: v for k, v in data_dict.items() if v not in [None, ""]}  # noqa: PLR6201
+        cleaned_data = {k: v for k, v in data_dict.items() if v not in [None, ""]}
         logger.info("Cleaned data item: %s", cleaned_data)
         return cleaned_data
 
-    def save_observations(self, valid_data: list[dict[str, Any]]) -> Response:
+    def save_observations(self, valid_data: list[Union[dict[str, Any], Observation]]) -> Response:
         """Save the valid observations to the database."""
         try:
+            logger.info(f"Saving {len(valid_data)} valid observations")
             with transaction.atomic():
-                Observation.objects.bulk_create([Observation(**data) for data in valid_data])
+                created_count = 0
+                for data in valid_data:
+                    if isinstance(data, Observation):
+                        # If it's already an Observation instance, just save it
+                        data.save()
+                    else:
+                        # If it's a dictionary, create a new Observation instance
+                        Observation.objects.create(**data)
+                        created_count += 1
+
+            invalidate_geojson_cache()
             return Response(
-                {"message": f"Successfully imported {len(valid_data)} observations."}, status=status.HTTP_201_CREATED
+                {"message": f"Successfully imported {created_count} observations."},
+                status=status.HTTP_201_CREATED
             )
         except IntegrityError as e:
             logger.exception("Error during bulk import")
             return Response(
-                {"error": f"An error occurred during bulk import: {e!s}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
+                {"error": f"An error occurred during bulk import: {e!s}"},
+                status=status.HTTP_500_INTERNAL_SERVER_ERROR
             )
 
     @method_decorator(ratelimit(key="ip", rate="60/m", method="GET", block=True))
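Reviewer note: save_observations now has to cope with a mixed list, because process_data returns already-saved Observation instances for updates and validated dicts for inserts. The dispatch, with logging and counting removed, is (a sketch assuming the vespadb models are importable; save_all is an illustrative name):

    from django.db import transaction
    from vespadb.observations.models import Observation

    def save_all(valid_data):
        # One atomic block gives all-or-nothing semantics: a single failing
        # row rolls back the whole import rather than leaving a partial batch.
        with transaction.atomic():
            for data in valid_data:
                if isinstance(data, Observation):
                    data.save()
                else:
                    Observation.objects.create(**data)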
@@ -780,11 +915,11 @@ def download_export(self, request: HttpRequest) -> Union[StreamingHttpResponse,
         except Exception as e:
             logger.error(f"Error streaming export: {str(e)}")
             return HttpResponseServerError("Error generating export")
-    
+
     @method_decorator(csrf_exempt)
     @action(detail=False, methods=['get'], permission_classes=[AllowAny])
     def export_direct(self, request: HttpRequest) -> Union[StreamingHttpResponse, JsonResponse]:
-        """Stream observations directly as CSV without using Celery."""
+        """Stream observations directly as CSV with optimized memory usage."""
         try:
             # Initialize the filterset with request parameters
             filterset = self.filterset_class(
@@ -801,7 +936,7 @@ def export_direct(self, request: HttpRequest) -> Union[StreamingHttpResponse, Js
             # Get filtered queryset
             queryset = filterset.qs
 
-            # Check count
+            # Check count with a more efficient query
             total_count = queryset.count()
             if total_count > 10000:
                 return JsonResponse({
@@ -816,14 +951,18 @@ def export_direct(self, request: HttpRequest) -> Union[StreamingHttpResponse, Js
                     request.user.municipalities.values_list('id', flat=True)
                 )
 
-            # Create the streaming response with data from the task module
-            from .tasks.export_utils import generate_rows
+            # Create the streaming response
             pseudo_buffer = Echo()
             writer = csv.writer(pseudo_buffer)
 
-            # Stream response with appropriate headers
             response = StreamingHttpResponse(
-                generate_rows(queryset, writer, is_admin, user_municipality_ids),
+                streaming_content=generate_rows(
+                    queryset=queryset,
+                    writer=writer,
+                    is_admin=is_admin,
+                    user_municipality_ids=user_municipality_ids,
+                    batch_size=200  # Smaller batch size for memory efficiency
+                ),
                 content_type='text/csv'
             )
@@ -842,7 +981,7 @@ def export_direct(self, request: HttpRequest) -> Union[StreamingHttpResponse, Js
         except Exception as e:
             logger.exception("Export failed")
             return JsonResponse({"error": str(e)}, status=500)
-    
+
 @require_GET
 def search_address(request: Request) -> JsonResponse:
     """
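Reviewer note: export_direct relies on the Echo pseudo-buffer defined near the top of views.py. For readers unfamiliar with the idiom, this is the standard Django pattern for streaming large CSV responses (self-contained sketch; stream_csv and the filename are illustrative, while the real view delegates row production to generate_rows):

    import csv
    from django.http import StreamingHttpResponse

    class Echo:
        """Pseudo-buffer: csv.writer 'writes' a row by returning it to the caller."""
        def write(self, value):
            return value

    def stream_csv(rows):
        # Each writer.writerow(...) call yields one formatted CSV line, so the
        # response never holds the full export in memory.
        writer = csv.writer(Echo())
        response = StreamingHttpResponse(
            (writer.writerow(row) for row in rows),
            content_type="text/csv",
        )
        response["Content-Disposition"] = 'attachment; filename="observations.csv"'
        return response

Because csv.writer never buffers, StreamingHttpResponse emits the export row by row instead of materializing all (up to 10,000) observations at once, which is what the total_count guard above is sized against.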