diff --git a/kive/container/ajax.py b/kive/container/ajax.py index 7323b0242..4378b6f52 100644 --- a/kive/container/ajax.py +++ b/kive/container/ajax.py @@ -223,6 +223,7 @@ def content_put(self, request, pk=None): response_data = dict(pipeline=['This field is required.']) else: container.write_content(content) + container.save() response_data = container.get_content() status_code = Response.status_code return Response(response_data, status_code) diff --git a/kive/container/models.py b/kive/container/models.py index ff696b963..3f308241b 100644 --- a/kive/container/models.py +++ b/kive/container/models.py @@ -12,7 +12,7 @@ from itertools import count from subprocess import STDOUT, CalledProcessError, check_output, check_call from tarfile import TarFile, TarInfo -from tempfile import mkdtemp, mkstemp, NamedTemporaryFile +from tempfile import mkdtemp, mkstemp import shutil import tarfile import io @@ -29,11 +29,10 @@ from django.dispatch import receiver from django.urls import reverse from django.utils import timezone -from django.forms.fields import FileField as FileFormField import django.utils.six as dsix from constants import maxlengths -from file_access_utils import compute_md5 +from file_access_utils import compute_md5, use_field_file from metadata.models import AccessControl, empty_removal_plan, remove_helper from stopwatch.models import Stopwatch import file_access_utils @@ -79,41 +78,6 @@ def remove(self): remove_helper(removal_plan) -class ContainerFileFormField(FileFormField): - def to_python(self, data): - """ Checks that the file-upload data contains a valid container. """ - f = super(ContainerFileFormField, self).to_python(data) - if f is None: - return None - - # We need to get a file object to validate. We might have a path or we might - # have to read the data out of memory. - if hasattr(data, 'temporary_file_path'): - Container.validate_container(data.temporary_file_path()) - else: - upload_name = getattr(data, 'name', 'container') - upload_base, upload_ext = os.path.splitext(upload_name) - with NamedTemporaryFile(prefix=upload_base, - suffix=upload_ext) as f_temp: - if hasattr(data, 'read'): - f_temp.write(data.read()) - else: - f_temp.write(data['content']) - f_temp.flush() - self.validate_container(f_temp.name) - - if hasattr(f, 'seek') and callable(f.seek): - f.seek(0) - return f - - -class ContainerFileField(models.FileField): - def formfield(self, **kwargs): - # noinspection PyTypeChecker - kwargs.setdefault('form_class', ContainerFileFormField) - return super(ContainerFileField, self).formfield(**kwargs) - - class ContainerNotChild(Exception): pass @@ -286,6 +250,11 @@ def validate_singularity_container(cls, file_path): raise ValidationError(cls.DEFAULT_ERROR_MESSAGES['invalid_singularity_container'], code='invalid_singularity_container') + def save(self, *args, **kwargs): + if not self.md5: + self.set_md5() + super(Container, self).save(*args, **kwargs) + def clean(self): """ Confirm that the file is of the correct type. @@ -300,7 +269,7 @@ def clean(self): # this step, we check for an "already validated" flag. if not getattr(self, "singularity_validated", False): fd, file_path = mkstemp() - with io.open(fd, mode="w+b") as f: + with use_field_file(self.file), io.open(fd, mode="w+b") as f: for chunk in self.file.chunks(): f.write(chunk) @@ -317,35 +286,23 @@ def clean(self): if self.file_type == Container.ZIP: try: - was_closed = self.file.closed - self.file.open() - try: + with use_field_file(self.file): with ZipFile(self.file): pass - finally: - if was_closed: - self.file.close() except BadZipfile: raise ValidationError(self.DEFAULT_ERROR_MESSAGES["invalid_archive"], code="invalid_archive") - else: # this is either a tarfile or a gzipped tar file + else: + assert self.file_type == Container.TAR try: - was_closed = self.file.closed - self.file.open() - try: + with use_field_file(self.file): with tarfile.open(fileobj=self.file, mode="r"): pass - finally: - if was_closed: - self.file.close() except tarfile.ReadError: raise ValidationError(self.DEFAULT_ERROR_MESSAGES["invalid_archive"], code="invalid_archive") - # Leave the file open and ready to go for whatever comes next in the processing. - self.file.open() # seeks to the 0 position if it's still open - def set_md5(self): """ Set this instance's md5 attribute. Note that this does not save the instance. @@ -353,9 +310,10 @@ def set_md5(self): This leaves self.file open and seek'd to the 0 position. :return: """ - self.file.open() # seeks to 0 if it was already open - self.md5 = compute_md5(self.file) - self.file.open() # leave it as we found it + if not self.file: + return + with use_field_file(self.file): + self.md5 = compute_md5(self.file) def validate_md5(self): """ @@ -412,9 +370,7 @@ def open_content(self, mode='r'): file_mode = 'rb+' else: raise ValueError('Unsupported mode for archive content: {!r}.'.format(mode)) - was_closed = self.file.closed - self.file.open(file_mode) - try: + with use_field_file(self.file, file_mode): if self.file_type == Container.ZIP: archive = ZipHandler(self.file, mode) elif self.file_type == Container.TAR: @@ -425,9 +381,6 @@ def open_content(self, mode='r'): self.file_type)) yield archive archive.close() - finally: - if was_closed: - self.file.close() def get_content(self, add_default=True): with self.open_content() as archive: @@ -460,6 +413,7 @@ def write_content(self, content): if file_name not in file_names: archive.write(file_name, pipeline_json) break + self.set_md5() self.create_app_from_content(content) def get_pipeline_state(self): diff --git a/kive/container/tests.py b/kive/container/tests.py index b424928ad..4134d8ecd 100644 --- a/kive/container/tests.py +++ b/kive/container/tests.py @@ -334,6 +334,8 @@ def test_create_content_and_app(self): app, = new_container.apps.all() self.assertEqual(expected_inputs, app.inputs) self.assertEqual(expected_outputs, app.outputs) + self.assertNotEqual('', new_container.md5) + self.assertIsNotNone(new_container.md5) def test_create_singularity_no_app(self): user = User.objects.first() @@ -504,6 +506,8 @@ def test_create_singularity(self): self.assertEquals(len(resp), start_count + 1) self.assertEquals(resp[0]['description'], expected_description) + self.assertNotEquals(resp[0]['md5'], '') + self.assertIsNotNone(resp[0]['md5']) def test_create_zip(self): expected_tag = "v1.0" @@ -546,6 +550,7 @@ def test_put_content(self): self.test_container.file.save( 'test.zip', ContentFile(self.create_zip_content().getvalue())) + old_md5 = self.test_container.md5 expected_content = dict(files=["bar.txt", "foo.txt"], pipeline=dict(default_config=dict(memory=400, threads=3), @@ -560,6 +565,10 @@ def test_put_content(self): content = self.content_view(request1, pk=self.detail_pk).data self.assertEqual(expected_content, content) + self.test_container.refresh_from_db() + new_md5 = self.test_container.md5 + self.assertNotEqual('', new_md5) + self.assertNotEqual(old_md5, new_md5) def test_put_bad_content(self): self.test_container.file_type = Container.ZIP diff --git a/kive/container/views.py b/kive/container/views.py index 80eef682c..84f670960 100644 --- a/kive/container/views.py +++ b/kive/container/views.py @@ -82,7 +82,6 @@ class ContainerCreate(CreateView, AdminViewMixin): def form_valid(self, form): form.instance.user = self.request.user form.instance.family = ContainerFamily.objects.get(pk=self.kwargs['family_id']) - form.instance.set_md5() response = super(ContainerCreate, self).form_valid(form) with transaction.atomic(): diff --git a/kive/file_access_utils.py b/kive/file_access_utils.py index 043d35daa..03b8e68a0 100755 --- a/kive/file_access_utils.py +++ b/kive/file_access_utils.py @@ -13,11 +13,11 @@ import stat import time import io +from contextlib import contextmanager from operator import itemgetter from django.conf import settings from django.utils import timezone -from django.core.files.base import File import django.utils.six as dsix from django.db import transaction @@ -203,6 +203,7 @@ def copyfile(src, dst, follow_symlinks=True): with python2.7, with the exception of the buffer size used in copying the file contents. """ + # noinspection PyUnresolvedReferences,PyProtectedMember if shutil._samefile(src, dst): raise SameFileError("{!r} and {!r} are the same file".format(src, dst)) @@ -443,6 +444,7 @@ def purge_unregistered_files(directory_to_scan, class_to_check, file_attr, bytes relative_path = os.path.relpath(absolute_path, settings.MEDIA_ROOT) all_files.append((absolute_path, mod_time, size, relative_path)) + # noinspection PyTypeChecker all_files = sorted(all_files, key=itemgetter(1)) bytes_purged = 0 @@ -468,3 +470,25 @@ def purge_unregistered_files(directory_to_scan, class_to_check, file_attr, bytes break return bytes_purged, files_purged, known_files, still_new + + +@contextmanager +def use_field_file(field_file, mode='rb'): + """ Context manager for FieldFile objects. + + Tries to leave a file object in the same state it was in when the context + manager started. + It's hard to tell when to close a FieldFile object. It opens implicitly + when you first read from it. Sometimes, it's an in-memory file object, and + it can't be reopened. + """ + was_closed = field_file.closed + field_file.open(mode) + start_position = field_file.tell() + try: + yield field_file + finally: + if was_closed: + field_file.close() + else: + field_file.seek(start_position)