Skip to content

Commit

Permalink
Update container MD5 when writing content, as part of #751.
Browse files Browse the repository at this point in the history
  • Loading branch information
donkirkby committed Feb 16, 2019
1 parent 8aa8c8d commit 301daa2
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 66 deletions.
1 change: 1 addition & 0 deletions kive/container/ajax.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def content_put(self, request, pk=None):
response_data = dict(pipeline=['This field is required.'])
else:
container.write_content(content)
container.save()
response_data = container.get_content()
status_code = Response.status_code
return Response(response_data, status_code)
Expand Down
82 changes: 18 additions & 64 deletions kive/container/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from itertools import count
from subprocess import STDOUT, CalledProcessError, check_output, check_call
from tarfile import TarFile, TarInfo
from tempfile import mkdtemp, mkstemp, NamedTemporaryFile
from tempfile import mkdtemp, mkstemp
import shutil
import tarfile
import io
Expand All @@ -29,11 +29,10 @@
from django.dispatch import receiver
from django.urls import reverse
from django.utils import timezone
from django.forms.fields import FileField as FileFormField
import django.utils.six as dsix

from constants import maxlengths
from file_access_utils import compute_md5
from file_access_utils import compute_md5, use_field_file
from metadata.models import AccessControl, empty_removal_plan, remove_helper
from stopwatch.models import Stopwatch
import file_access_utils
Expand Down Expand Up @@ -79,41 +78,6 @@ def remove(self):
remove_helper(removal_plan)


class ContainerFileFormField(FileFormField):
def to_python(self, data):
""" Checks that the file-upload data contains a valid container. """
f = super(ContainerFileFormField, self).to_python(data)
if f is None:
return None

# We need to get a file object to validate. We might have a path or we might
# have to read the data out of memory.
if hasattr(data, 'temporary_file_path'):
Container.validate_container(data.temporary_file_path())
else:
upload_name = getattr(data, 'name', 'container')
upload_base, upload_ext = os.path.splitext(upload_name)
with NamedTemporaryFile(prefix=upload_base,
suffix=upload_ext) as f_temp:
if hasattr(data, 'read'):
f_temp.write(data.read())
else:
f_temp.write(data['content'])
f_temp.flush()
self.validate_container(f_temp.name)

if hasattr(f, 'seek') and callable(f.seek):
f.seek(0)
return f


class ContainerFileField(models.FileField):
def formfield(self, **kwargs):
# noinspection PyTypeChecker
kwargs.setdefault('form_class', ContainerFileFormField)
return super(ContainerFileField, self).formfield(**kwargs)


class ContainerNotChild(Exception):
pass

Expand Down Expand Up @@ -286,6 +250,11 @@ def validate_singularity_container(cls, file_path):
raise ValidationError(cls.DEFAULT_ERROR_MESSAGES['invalid_singularity_container'],
code='invalid_singularity_container')

def save(self, *args, **kwargs):
if not self.md5:
self.set_md5()
super(Container, self).save(*args, **kwargs)

def clean(self):
"""
Confirm that the file is of the correct type.
Expand All @@ -300,7 +269,7 @@ def clean(self):
# this step, we check for an "already validated" flag.
if not getattr(self, "singularity_validated", False):
fd, file_path = mkstemp()
with io.open(fd, mode="w+b") as f:
with use_field_file(self.file), io.open(fd, mode="w+b") as f:
for chunk in self.file.chunks():
f.write(chunk)

Expand All @@ -317,45 +286,34 @@ def clean(self):

if self.file_type == Container.ZIP:
try:
was_closed = self.file.closed
self.file.open()
try:
with use_field_file(self.file):
with ZipFile(self.file):
pass
finally:
if was_closed:
self.file.close()
except BadZipfile:
raise ValidationError(self.DEFAULT_ERROR_MESSAGES["invalid_archive"],
code="invalid_archive")

else: # this is either a tarfile or a gzipped tar file
else:
assert self.file_type == Container.TAR
try:
was_closed = self.file.closed
self.file.open()
try:
with use_field_file(self.file):
with tarfile.open(fileobj=self.file, mode="r"):
pass
finally:
if was_closed:
self.file.close()
except tarfile.ReadError:
raise ValidationError(self.DEFAULT_ERROR_MESSAGES["invalid_archive"],
code="invalid_archive")

# Leave the file open and ready to go for whatever comes next in the processing.
self.file.open() # seeks to the 0 position if it's still open

def set_md5(self):
"""
Set this instance's md5 attribute. Note that this does not save the instance.
This leaves self.file open and seek'd to the 0 position.
:return:
"""
self.file.open() # seeks to 0 if it was already open
self.md5 = compute_md5(self.file)
self.file.open() # leave it as we found it
if not self.file:
return
with use_field_file(self.file):
self.md5 = compute_md5(self.file)

def validate_md5(self):
"""
Expand Down Expand Up @@ -412,9 +370,7 @@ def open_content(self, mode='r'):
file_mode = 'rb+'
else:
raise ValueError('Unsupported mode for archive content: {!r}.'.format(mode))
was_closed = self.file.closed
self.file.open(file_mode)
try:
with use_field_file(self.file, file_mode):
if self.file_type == Container.ZIP:
archive = ZipHandler(self.file, mode)
elif self.file_type == Container.TAR:
Expand All @@ -425,9 +381,6 @@ def open_content(self, mode='r'):
self.file_type))
yield archive
archive.close()
finally:
if was_closed:
self.file.close()

def get_content(self, add_default=True):
with self.open_content() as archive:
Expand Down Expand Up @@ -460,6 +413,7 @@ def write_content(self, content):
if file_name not in file_names:
archive.write(file_name, pipeline_json)
break
self.set_md5()
self.create_app_from_content(content)

def get_pipeline_state(self):
Expand Down
9 changes: 9 additions & 0 deletions kive/container/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,8 @@ def test_create_content_and_app(self):
app, = new_container.apps.all()
self.assertEqual(expected_inputs, app.inputs)
self.assertEqual(expected_outputs, app.outputs)
self.assertNotEqual('', new_container.md5)
self.assertIsNotNone(new_container.md5)

def test_create_singularity_no_app(self):
user = User.objects.first()
Expand Down Expand Up @@ -504,6 +506,8 @@ def test_create_singularity(self):

self.assertEquals(len(resp), start_count + 1)
self.assertEquals(resp[0]['description'], expected_description)
self.assertNotEquals(resp[0]['md5'], '')
self.assertIsNotNone(resp[0]['md5'])

def test_create_zip(self):
expected_tag = "v1.0"
Expand Down Expand Up @@ -546,6 +550,7 @@ def test_put_content(self):
self.test_container.file.save(
'test.zip',
ContentFile(self.create_zip_content().getvalue()))
old_md5 = self.test_container.md5
expected_content = dict(files=["bar.txt", "foo.txt"],
pipeline=dict(default_config=dict(memory=400,
threads=3),
Expand All @@ -560,6 +565,10 @@ def test_put_content(self):
content = self.content_view(request1, pk=self.detail_pk).data

self.assertEqual(expected_content, content)
self.test_container.refresh_from_db()
new_md5 = self.test_container.md5
self.assertNotEqual('', new_md5)
self.assertNotEqual(old_md5, new_md5)

def test_put_bad_content(self):
self.test_container.file_type = Container.ZIP
Expand Down
1 change: 0 additions & 1 deletion kive/container/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ class ContainerCreate(CreateView, AdminViewMixin):
def form_valid(self, form):
form.instance.user = self.request.user
form.instance.family = ContainerFamily.objects.get(pk=self.kwargs['family_id'])
form.instance.set_md5()

response = super(ContainerCreate, self).form_valid(form)
with transaction.atomic():
Expand Down
26 changes: 25 additions & 1 deletion kive/file_access_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
import stat
import time
import io
from contextlib import contextmanager
from operator import itemgetter

from django.conf import settings
from django.utils import timezone
from django.core.files.base import File
import django.utils.six as dsix
from django.db import transaction

Expand Down Expand Up @@ -203,6 +203,7 @@ def copyfile(src, dst, follow_symlinks=True):
with python2.7, with the exception of the buffer size
used in copying the file contents.
"""
# noinspection PyUnresolvedReferences,PyProtectedMember
if shutil._samefile(src, dst):
raise SameFileError("{!r} and {!r} are the same file".format(src, dst))

Expand Down Expand Up @@ -443,6 +444,7 @@ def purge_unregistered_files(directory_to_scan, class_to_check, file_attr, bytes
relative_path = os.path.relpath(absolute_path, settings.MEDIA_ROOT)
all_files.append((absolute_path, mod_time, size, relative_path))

# noinspection PyTypeChecker
all_files = sorted(all_files, key=itemgetter(1))

bytes_purged = 0
Expand All @@ -468,3 +470,25 @@ def purge_unregistered_files(directory_to_scan, class_to_check, file_attr, bytes
break

return bytes_purged, files_purged, known_files, still_new


@contextmanager
def use_field_file(field_file, mode='rb'):
""" Context manager for FieldFile objects.
Tries to leave a file object in the same state it was in when the context
manager started.
It's hard to tell when to close a FieldFile object. It opens implicitly
when you first read from it. Sometimes, it's an in-memory file object, and
it can't be reopened.
"""
was_closed = field_file.closed
field_file.open(mode)
start_position = field_file.tell()
try:
yield field_file
finally:
if was_closed:
field_file.close()
else:
field_file.seek(start_position)

0 comments on commit 301daa2

Please sign in to comment.