Fix hanging Ctrl+C on S3 downloads #673

Merged · 4 commits · Feb 25, 2014
94 changes: 41 additions & 53 deletions awscli/customizations/s3/executor.py
@@ -11,11 +11,11 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import logging
from six.moves import queue as Queue
from six.moves import queue
import sys
import threading

from awscli.customizations.s3.utils import NoBlockQueue, uni_print, \
from awscli.customizations.s3.utils import uni_print, \
IORequest, IOCloseRequest


@@ -26,22 +26,22 @@
class Executor(object):
"""
This class is in charge of all of the threads. It starts up the threads
and cleans up the threads when done. The two types of threads the
and cleans up the threads when finished. The two types of threads the
``Executor`` runs are a worker thread and a print thread.
"""
def __init__(self, done, num_threads, result_queue,
quiet, interrupt, max_queue_size, write_queue):
self.queue = None
self.done = done
def __init__(self, num_threads, result_queue,
quiet, max_queue_size, write_queue):
self._max_queue_size = max_queue_size
self.queue = queue.Queue(maxsize=self._max_queue_size)
self.num_threads = num_threads
self.result_queue = result_queue
self.quiet = quiet
self.interrupt = interrupt
self.threads_list = []
self._max_queue_size = max_queue_size
self.write_queue = write_queue
self.print_thread = None
self.io_thread = None
self.print_thread = PrintThread(self.result_queue,
self.quiet)
self.print_thread.daemon = True
self.io_thread = IOWriterThread(self.write_queue)
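
A note on this refactor: the print and IO threads, and the bounded task queue, are now built eagerly in `__init__` rather than lazily in `start()`, so an `Executor` is fully wired before any thread runs. The `maxsize` bound also gives the task queue natural backpressure: `put()` blocks once the queue is full, so submission can never race arbitrarily far ahead of the workers. A minimal, self-contained sketch of that behavior (sizes are hypothetical):

```python
import queue
import threading
import time

# A bounded queue: put() blocks once maxsize items are pending, so a
# fast producer cannot run arbitrarily far ahead of its consumers.
tasks = queue.Queue(maxsize=2)  # hypothetical size, for illustration

def consumer():
    while True:
        task = tasks.get(True)
        if task is None:       # sentinel: stop consuming
            break
        time.sleep(0.1)        # simulate slow work

t = threading.Thread(target=consumer)
t.daemon = True
t.start()

for i in range(5):
    tasks.put(i)               # blocks whenever 2 items are pending
tasks.put(None)                # enqueue the sentinel
t.join()                       # consumer exits after draining
```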

@property
def num_tasks_failed(self):
@@ -51,17 +51,14 @@ def num_tasks_failed(self):
return tasks_failed

def start(self):
self.print_thread = PrintThread(self.result_queue, self.done,
self.quiet, self.interrupt)
self.print_thread.daemon = True
self.io_thread = IOWriterThread(self.write_queue, self.done)
self.io_thread.start()
self.threads_list.append(self.io_thread)
self.queue = NoBlockQueue(self.interrupt, maxsize=self._max_queue_size)
self.threads_list.append(self.print_thread)
# Note that we're *not* adding the IO thread to the threads_list.
# There's a specific shutdown order we need and we're going to be
# explicit about it rather than relying on the threads_list order.
# See .join() for more info.
self.print_thread.start()
for i in range(self.num_threads):
worker = Worker(queue=self.queue, done=self.done)
worker = Worker(queue=self.queue)
worker.setDaemon(True)
self.threads_list.append(worker)
worker.start()
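
Worth calling out: the workers (via the older `setDaemon(True)` spelling) and the print thread are all daemonized. That is the safety net if the explicit shutdown in `join()` is ever skipped, since daemon threads cannot keep the interpreter alive. A tiny demonstration of the property:

```python
import threading
import time

def background():
    while True:        # never exits on its own
        time.sleep(0.1)

t = threading.Thread(target=background)
t.daemon = True        # same effect as the older t.setDaemon(True)
t.start()

# The main thread ends here. With daemon=True the process exits
# cleanly anyway; with daemon=False it would hang forever waiting
# for `background` to finish.
```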
@@ -73,31 +70,35 @@ def submit(self, task):
LOGGER.debug("Submitting task: %s", task)
self.queue.put(task)

def wait(self):
"""
This is the function used to wait on all of the tasks to finish
in the ``Executor``.
"""
self.queue.join()

def join(self):
"""
This is used to clean up the ``Executor``.
"""
self.write_queue.put(QUEUE_END_SENTINEL)
self.result_queue.put(QUEUE_END_SENTINEL)
for i in range(self.num_threads):
LOGGER.debug("Queueing end sentinel for worker thread.")
self.queue.put(QUEUE_END_SENTINEL)

for thread in self.threads_list:
LOGGER.debug("Waiting for thread to shutdown: %s", thread)
thread.join()
LOGGER.debug("Thread has been shutdown: %s", thread)

LOGGER.debug("Queueing end sentinel for result thread.")
self.result_queue.put(QUEUE_END_SENTINEL)

LOGGER.debug("Queueing end sentinel for IO thread.")
self.write_queue.put(QUEUE_END_SENTINEL)
LOGGER.debug("Waiting for result thread to shutdown.")
self.print_thread.join()
LOGGER.debug("Waiting for IO thread to shutdown.")
self.io_thread.join()
LOGGER.debug("All threads have been shutdown.")


class IOWriterThread(threading.Thread):
def __init__(self, queue, done):
def __init__(self, queue):
threading.Thread.__init__(self)
self.queue = queue
self.done = done
self.fd_descriptor_cache = {}

def run(self):
@@ -137,28 +138,26 @@ class Worker(threading.Thread):
This thread is in charge of performing the tasks provided via
the main queue ``queue``.
"""
def __init__(self, queue, done):
def __init__(self, queue):
threading.Thread.__init__(self)
# This is the queue to which work (tasks) is submitted.
self.queue = queue
self.done = done

def run(self):
while True:
try:
function = self.queue.get(True)
if function is QUEUE_END_SENTINEL:
self.queue.task_done()
LOGGER.debug("End sentinel received in worker thread, "
"shutting down worker thread.")
break
try:
LOGGER.debug("Worker thread invoking task: %s", function)
function()
except Exception as e:
LOGGER.debug('Error calling task: %s', e, exc_info=True)
self.queue.task_done()
except Queue.Empty:
except queue.Empty:
pass
if self.done.isSet():
break
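
This hunk shows why the `done` event could hang Ctrl+C: with a fully blocking `get(True)`, the event was only consulted after a task arrived (and `queue.Empty` can never be raised without a timeout), so an idle worker slept in `get()` forever. One sentinel per worker wakes every blocked `get()` deterministically. The two shapes side by side, as a sketch:

```python
import queue

SENTINEL = object()

# Old shape (hang-prone): the done event is only checked after a task
# arrives, and queue.Empty cannot fire without a timeout, so an idle
# worker sleeps in get() indefinitely.
def run_with_done_event(task_queue, done_event):
    while True:
        try:
            func = task_queue.get(True)
            func()
            task_queue.task_done()
        except queue.Empty:
            pass
        if done_event.is_set():
            break

# New shape: one sentinel per worker wakes every blocked get().
def run_with_sentinel(task_queue):
    while True:
        func = task_queue.get(True)
        if func is SENTINEL:
            break
        func()
```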


class PrintThread(threading.Thread):
@@ -181,12 +180,10 @@ class PrintThread(threading.Thread):
deprecated, will be removed in the future).

"""
def __init__(self, result_queue, done, quiet, interrupt):
def __init__(self, result_queue, quiet):
threading.Thread.__init__(self)
self._progress_dict = {}
self._interrupt = interrupt
self._result_queue = result_queue
self._done = done
self._quiet = quiet
self._progress_length = 0
self._num_parts = 0
@@ -214,26 +211,17 @@ def run(self):
try:
print_task = self._result_queue.get(True)
if print_task is QUEUE_END_SENTINEL:
self._result_queue.task_done()
if self._needs_newline:
sys.stdout.write('\n')
break
LOGGER.debug("Received print task: %s", print_task)
try:
self._process_print_task(print_task)
except Exception as e:
LOGGER.debug("Error processing print task: %s", e,
exc_info=True)
finally:
# Because the shutdown logic requires that the print
# queue finish, we need to have all the print tasks
# finished, even if an exception happens trying to print
# them.
self._result_queue.task_done()
except Queue.Empty:
except queue.Empty:
pass
if self._done.isSet():
if self._needs_newline:
sys.stdout.write('\n')
break

def _process_print_task(self, print_task):
print_str = print_task['message']
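
Note that the `task_done()` bookkeeping disappears from `run()` along with the `done` event; it only existed so `Queue.join()` could detect a drained queue, and with shutdown now driven by sentinels plus `Thread.join()` the unfinished-task counter has no consumer left. The slimmed-down consumer loop, sketched with a hypothetical `process_task` standing in for `_process_print_task`:

```python
import logging
import sys

LOGGER = logging.getLogger(__name__)
SENTINEL = object()

def print_loop(result_queue, process_task):
    # process_task stands in for _process_print_task; assume it
    # returns whether a trailing newline is still owed to stdout.
    needs_newline = False
    while True:
        print_task = result_queue.get(True)
        if print_task is SENTINEL:
            if needs_newline:
                sys.stdout.write('\n')
            break
        try:
            needs_newline = process_task(print_task)
        except Exception as e:
            # A malformed print task must not kill the consumer.
            LOGGER.debug("Error processing print task: %s", e,
                         exc_info=True)
```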
@@ -265,7 +253,7 @@ def _process_print_task(self, print_task):
self._file_count += 1

is_done = self._total_files == self._file_count
if not self._interrupt.isSet() and not is_done:
if not is_done:
prog_str = "Completed %s " % self._num_parts
num_files = self._total_files
if self._total_files != '...':
29 changes: 9 additions & 20 deletions awscli/customizations/s3/s3handler.py
@@ -13,11 +13,11 @@
import logging
import math
import os
import threading
from six.moves import queue

from awscli.customizations.s3.constants import MULTI_THRESHOLD, CHUNKSIZE, \
NUM_THREADS, MAX_UPLOAD_SIZE, MAX_QUEUE_SIZE
from awscli.customizations.s3.utils import NoBlockQueue, find_chunksize, \
from awscli.customizations.s3.utils import find_chunksize, \
operate, find_bucket_key, relative_path
from awscli.customizations.s3.executor import Executor
from awscli.customizations.s3 import tasks
@@ -36,14 +36,11 @@ class pull tasks from to complete.
def __init__(self, session, params, multi_threshold=MULTI_THRESHOLD,
chunksize=CHUNKSIZE):
self.session = session
self.done = threading.Event()
self.interrupt = threading.Event()
self.result_queue = NoBlockQueue()
self.result_queue = queue.Queue()
# The write_queue has potential for optimizations, so the constant
# for maxsize is scoped to this class (as opposed to constants.py)
# so we have the ability to change this value later.
self.write_queue = NoBlockQueue(self.interrupt,
maxsize=self.MAX_IO_QUEUE_SIZE)
self.write_queue = queue.Queue(maxsize=self.MAX_IO_QUEUE_SIZE)
self.params = {'dryrun': False, 'quiet': False, 'acl': None,
'guess_mime_type': True, 'sse': False,
'storage_class': None, 'website_redirect': None,
@@ -58,9 +55,9 @@ def __init__(self, session, params, multi_threshold=MULTI_THRESHOLD,
self.multi_threshold = multi_threshold
self.chunksize = chunksize
self.executor = Executor(
done=self.done, num_threads=NUM_THREADS, result_queue=self.result_queue,
quiet=self.params['quiet'], interrupt=self.interrupt,
max_queue_size=MAX_QUEUE_SIZE, write_queue=self.write_queue
num_threads=NUM_THREADS, result_queue=self.result_queue,
quiet=self.params['quiet'], max_queue_size=MAX_QUEUE_SIZE,
write_queue=self.write_queue
)
self._multipart_uploads = []
self._multipart_downloads = []
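
For context on where `write_queue` drains: a single `IOWriterThread` (its `run()` is collapsed in this diff) is the only consumer, which serializes all disk writes behind one bounded queue. A hedged sketch of what such a consumer loop looks like, assuming the `IORequest`/`IOCloseRequest` shapes that `tasks.py` constructs:

```python
from collections import namedtuple

# Assumed shapes, matching how tasks.py constructs them:
IORequest = namedtuple('IORequest', ['filename', 'offset', 'data'])
IOCloseRequest = namedtuple('IOCloseRequest', ['filename'])
SENTINEL = object()

def io_writer_loop(write_queue):
    fd_cache = {}  # filename -> open file object
    while True:
        request = write_queue.get(True)
        if request is SENTINEL:
            break
        if isinstance(request, IOCloseRequest):
            fileobj = fd_cache.pop(request.filename, None)
            if fileobj is not None:
                fileobj.close()
            continue
        # IORequest: the file was pre-created by CreateLocalFileTask,
        # so 'rb+' lets us seek anywhere and overwrite in place.
        if request.filename not in fd_cache:
            fd_cache[request.filename] = open(request.filename, 'rb+')
        fileobj = fd_cache[request.filename]
        fileobj.seek(request.offset)
        fileobj.write(request.data)
    for fileobj in fd_cache.values():  # close stragglers on shutdown
        fileobj.close()
```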
@@ -74,31 +71,23 @@ def call(self, files):
essentially a thread of execution for a thread to follow. These
tasks are then submitted to the main executor.
"""
self.done.clear()
self.interrupt.clear()
try:
self.executor.start()
total_files, total_parts = self._enqueue_tasks(files)
self.executor.print_thread.set_total_files(total_files)
self.executor.print_thread.set_total_parts(total_parts)
self.executor.wait()
self.result_queue.join()

except Exception as e:
LOGGER.debug('Exception caught during task execution: %s',
str(e), exc_info=True)
self.result_queue.put({'message': str(e), 'error': True})
except KeyboardInterrupt:
self.interrupt.set()
self.result_queue.put({'message': "Cleaning up. Please wait...",
'error': False})
'error': True})
self._shutdown()
return self.executor.num_tasks_failed

def _shutdown(self):
# self.done will tell threads to shutdown.
self.done.set()
# This waill wait until all the threads are joined.
# This will wait until all the threads are joined.
self.executor.join()
# And finally we need to make a pass through all the existing
# multipart uploads and abort any pending multipart uploads.
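
The interrupt path is the user-visible half of the fix: Ctrl+C now just enqueues a message and falls through to the same `_shutdown()` call as normal completion, instead of flipping shared events. The resulting control flow of `call()`, condensed from the hunks above:

```python
def call(self, files):
    try:
        self.executor.start()
        total_files, total_parts = self._enqueue_tasks(files)
        self.executor.print_thread.set_total_files(total_files)
        self.executor.print_thread.set_total_parts(total_parts)
    except Exception as e:
        self.result_queue.put({'message': str(e), 'error': True})
    except KeyboardInterrupt:
        self.result_queue.put({'message': "Cleaning up. Please wait...",
                               'error': True})
    # A single exit path: success, errors, and Ctrl+C all funnel
    # into the same sentinel-driven shutdown.
    self._shutdown()
    return self.executor.num_tasks_failed
```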
28 changes: 21 additions & 7 deletions awscli/customizations/s3/tasks.py
@@ -258,13 +258,24 @@ def __init__(self, context, filename):

def __call__(self):
dirname = os.path.dirname(self._filename.dest)
if not os.path.isdir(dirname):
os.makedirs(dirname)
# Always create the file. Even if it exists, we need to
# wipe out the existing contents.
with open(self._filename.dest, 'wb'):
pass
self._context.announce_file_created()
try:
if not os.path.isdir(dirname):
try:
os.makedirs(dirname)
except OSError:
# It's possible that between the isdir check and the makedirs
# call, another thread has come along and created the
# directory. In this case the directory already exists and we
# can move on.
pass
# Always create the file. Even if it exists, we need to
# wipe out the existing contents.
with open(self._filename.dest, 'wb'):
pass
except Exception as e:
self._context.cancel()
else:
self._context.announce_file_created()
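
The `makedirs` guard added above is the classic check-then-act race fix: two download threads can both see the directory as missing, and the loser's `os.makedirs` raises `OSError`. One caveat worth noting: a blanket `except OSError` also swallows unrelated failures such as permission errors. A tighter sketch of the same idea:

```python
import errno
import os

def ensure_dirs(dirname):
    """Create dirname, tolerating a concurrent creator."""
    try:
        os.makedirs(dirname)
    except OSError as e:
        # Only swallow "already exists"; let EACCES and friends
        # propagate instead of being silently ignored.
        if e.errno != errno.EEXIST:
            raise
    # On Python 3.2+, the one-liner equivalent:
    # os.makedirs(dirname, exist_ok=True)
```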


class CompleteDownloadTask(object):
@@ -355,6 +366,7 @@ def _download_part(self):
result = {'message': message, 'error': False,
'total_parts': total_parts}
self._result_queue.put(result)
LOGGER.debug("Task complete: %s", self)
return
except (socket.timeout, socket.error) as e:
LOGGER.debug("Socket timeout caught, retrying request, "
@@ -378,7 +390,9 @@ def _queue_writes(self, body):
current = body.read(iterate_chunk_size)
while current:
offset = self._part_number * self._chunk_size + amount_read
LOGGER.debug("Submitting IORequest to write queue.")
self._io_queue.put(IORequest(self._filename.dest, offset, current))
LOGGER.debug("Request successfully submitted.")
amount_read += len(current)
current = body.read(iterate_chunk_size)
# Change log message.
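
The offset arithmetic in `_queue_writes` is what lets parts download in parallel yet land in the right place: each part owns the byte range starting at `part_number * chunk_size`, and sub-chunk reads advance within it. In sketch form, reusing the assumed `IORequest` shape:

```python
from collections import namedtuple

IORequest = namedtuple('IORequest', ['filename', 'offset', 'data'])

def queue_writes(body, io_queue, dest, part_number, chunk_size,
                 iterate_chunk_size):
    """Stream one part's body into the IO queue in small pieces."""
    amount_read = 0
    current = body.read(iterate_chunk_size)
    while current:
        # Absolute position of this piece in the final file: parts
        # are laid out back to back, chunk_size bytes apart.
        offset = part_number * chunk_size + amount_read
        io_queue.put(IORequest(dest, offset, current))
        amount_read += len(current)
        current = body.read(iterate_chunk_size)
```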
27 changes: 2 additions & 25 deletions awscli/customizations/s3/utils.py
@@ -20,11 +20,10 @@
from functools import partial

from six import PY3
from six.moves import queue as Queue
from dateutil.tz import tzlocal

from awscli.customizations.s3.constants import QUEUE_TIMEOUT_WAIT, \
MAX_PARTS, MAX_SINGLE_UPLOAD_SIZE
from awscli.customizations.s3.constants import MAX_PARTS
from awscli.customizations.s3.constants import MAX_SINGLE_UPLOAD_SIZE


class MD5Error(Exception):
@@ -34,28 +33,6 @@ class MD5Error(Exception):
pass


class NoBlockQueue(Queue.Queue):
"""
This queue ensures that joining does not block interrupt signals.
It also contains a threading event ``interrupt`` that breaks the
while loop if signaled. The ``interrupt`` signal is optional.
If left out, this should act like a normal queue.
"""
def __init__(self, interrupt=None, maxsize=0):
Queue.Queue.__init__(self, maxsize=maxsize)
self.interrupt = interrupt

def join(self):
self.all_tasks_done.acquire()
try:
while self.unfinished_tasks:
if self.interrupt and self.interrupt.isSet():
break
self.all_tasks_done.wait(QUEUE_TIMEOUT_WAIT)
finally:
self.all_tasks_done.release()
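
For the record, the class deleted here was itself a workaround: a plain `Queue.join()` waits on a condition the main thread cannot escape on Ctrl+C (on Python 2, an untimed wait is not interruptible by signals), so `NoBlockQueue` polled `all_tasks_done.wait()` with a timeout and checked an `interrupt` event. The flaw, as I read the PR, is that the event could never wake a thread already blocked in `get()`. A miniature reproduction of that hang:

```python
import queue
import threading

q = queue.Queue()
done = threading.Event()

def worker():
    while True:
        q.get(True)        # blocks forever once the queue is empty
        q.task_done()
        if done.is_set():  # never reached while blocked above
            break

t = threading.Thread(target=worker)
t.daemon = True            # only so this demo script can still exit
t.start()

done.set()                 # the "interrupt" arrives...
t.join(timeout=1)
print("worker still alive?", t.is_alive())  # True: the hang in miniature
```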


def find_bucket_key(s3_path):
"""
This is a helper function that given an s3 path such that the path is of