From cb786bb07f23af7792f05a940a50da08ee11d406 Mon Sep 17 00:00:00 2001
From: Brian McFee
Date: Sun, 10 Jul 2016 09:57:41 -0400
Subject: [PATCH] refactored file-or-str opening. fixes #191 hobgoblins.

---
 mir_eval/io.py             | 209 ++++++++++++++++---------------------
 tests/test_input_output.py |   4 +-
 2 files changed, 90 insertions(+), 123 deletions(-)

diff --git a/mir_eval/io.py b/mir_eval/io.py
index 9aaf4490..4e95b345 100644
--- a/mir_eval/io.py
+++ b/mir_eval/io.py
@@ -2,6 +2,7 @@
 Functions for loading in annotations from files in different formats.
 """
 
+import contextlib
 import numpy as np
 import re
 import warnings
@@ -12,6 +13,25 @@ from . import key
 
 
 
+@contextlib.contextmanager
+def _open(file_or_str, **kwargs):
+    '''Either open a file handle, or use an existing file-like object.
+
+    This will behave as the `open` function if `file_or_str` is a string.
+
+    If `file_or_str` has the `read` attribute, it will return `file_or_str`.
+
+    Otherwise, an `IOError` is raised.
+    '''
+    if hasattr(file_or_str, 'read'):
+        yield file_or_str
+    elif isinstance(file_or_str, six.string_types):
+        with open(file_or_str, **kwargs) as file_desc:
+            yield file_desc
+    else:
+        raise IOError('Invalid file-or-str object: {}'.format(file_or_str))
+
+
 def load_delimited(filename, converters, delimiter=r'\s+'):
     r"""Utility function for loading in data from an annotation file where
     columns are delimited. The number of columns is inferred from the length of
@@ -49,51 +69,33 @@ def load_delimited(filename, converters, delimiter=r'\s+'):
     # Create re object for splitting lines
     splitter = re.compile(delimiter)
 
-    # Keep track of whether we create our own file handle
-    own_fh = False
-    # If the filename input is a string, need to open it
-    if isinstance(filename, six.string_types):
-        # Remember that we need to close it later
-        own_fh = True
-        # Open the file for reading
-        input_file = open(filename, 'r')
-    # If the provided has a read attribute, we can use it as a file handle
-    elif hasattr(filename, 'read'):
-        input_file = filename
-    # Raise error otherwise
-    else:
-        raise ValueError('filename must be a string or file handle')
-
     # Note: we do io manually here for two reasons.
     # 1. The csv module has difficulties with unicode, which may lead
     #    to failures on certain annotation strings
     #
     # 2. numpy's text loader does not handle non-numeric data
     #
-    for row, line in enumerate(input_file, 1):
-        # Split each line using the supplied delimiter
-        data = splitter.split(line.strip(), n_columns - 1)
-
-        # Throw a helpful error if we got an unexpected # of columns
-        if n_columns != len(data):
-            raise ValueError('Expected {} columns, got {} at '
-                             '{}:{:d}:\n\t{}'.format(n_columns, len(data),
-                                                     filename, row, line))
-
-        for value, column, converter in zip(data, columns, converters):
-            # Try converting the value, throw a helpful error on failure
-            try:
-                converted_value = converter(value)
-            except:
-                raise ValueError("Couldn't convert value {} using {} "
-                                 "found at {}:{:d}:\n\t{}".format(
-                                     value, converter.__name__, filename, row,
-                                     line))
-            column.append(converted_value)
-
-    # Close the file handle if we opened it
-    if own_fh:
-        input_file.close()
+    with _open(filename, mode='r') as input_file:
+        for row, line in enumerate(input_file, 1):
+            # Split each line using the supplied delimiter
+            data = splitter.split(line.strip(), n_columns - 1)
+
+            # Throw a helpful error if we got an unexpected # of columns
+            if n_columns != len(data):
+                raise ValueError('Expected {} columns, got {} at '
+                                 '{}:{:d}:\n\t{}'.format(n_columns, len(data),
+                                                         filename, row, line))
+
+            for value, column, converter in zip(data, columns, converters):
+                # Try converting the value, throw a helpful error on failure
+                try:
+                    converted_value = converter(value)
+                except:
+                    raise ValueError("Couldn't convert value {} using {} "
+                                     "found at {}:{:d}:\n\t{}".format(
+                                         value, converter.__name__, filename,
+                                         row, line))
+                column.append(converted_value)
 
     # Sane output
     if n_columns == 1:
@@ -313,54 +315,36 @@ def load_patterns(filename):
 
     """
 
-    # Keep track of whether we create our own file handle
-    own_fh = False
-    # If the filename input is a string, need to open it
-    if isinstance(filename, six.string_types):
-        # Remember that we need to close it later
-        own_fh = True
-        # Open the file for reading
-        input_file = open(filename, 'r')
-    # If the provided has a read attribute, we can use it as a file handle
-    elif hasattr(filename, 'read'):
-        input_file = filename
-    # Raise error otherwise
-    else:
-        raise ValueError('filename must be a string or file handle')
-
     # List with all the patterns
     pattern_list = []
     # Current pattern, which will contain all occs
     pattern = []
     # Current occurrence, containing (onset, midi)
     occurrence = []
-    for line in input_file.readlines():
-        if "pattern" in line:
-            if occurrence != []:
-                pattern.append(occurrence)
-            if pattern != []:
-                pattern_list.append(pattern)
-            occurrence = []
-            pattern = []
-            continue
-        if "occurrence" in line:
-            if occurrence != []:
-                pattern.append(occurrence)
-            occurrence = []
-            continue
-        string_values = line.split(",")
-        onset_midi = (float(string_values[0]), float(string_values[1]))
-        occurrence.append(onset_midi)
-
-    # Add last occurrence and pattern to pattern_list
-    if occurrence != []:
-        pattern.append(occurrence)
-    if pattern != []:
-        pattern_list.append(pattern)
-
-    # If we opened an input file, we need to close it
-    if own_fh:
-        input_file.close()
+    with _open(filename, mode='r') as input_file:
+        for line in input_file.readlines():
+            if "pattern" in line:
+                if occurrence != []:
+                    pattern.append(occurrence)
+                if pattern != []:
+                    pattern_list.append(pattern)
+                occurrence = []
+                pattern = []
+                continue
+            if "occurrence" in line:
+                if occurrence != []:
+                    pattern.append(occurrence)
+                occurrence = []
+                continue
+            string_values = line.split(",")
+            onset_midi = (float(string_values[0]), float(string_values[1]))
+            occurrence.append(onset_midi)
+
+        # Add last occurrence and pattern to pattern_list
+        if occurrence != []:
+            pattern.append(occurrence)
+        if pattern != []:
+            pattern_list.append(pattern)
 
     return pattern_list
 
@@ -523,49 +507,32 @@ def load_ragged_time_series(filename, dtype=float, delimiter=r'\s+',
     # Create re object for splitting lines
     splitter = re.compile(delimiter)
 
-    # Keep track of whether we create our own file handle
-    own_fh = False
-    # If the filename input is a string, need to open it
-    if isinstance(filename, six.string_types):
-        # Remember that we need to close it later
-        own_fh = True
-        # Open the file for reading
-        input_file = open(filename, 'r')
-    # If the provided has a read attribute, we can use it as a file handle
-    elif hasattr(filename, 'read'):
-        input_file = filename
-    # Raise error otherwise
-    else:
-        raise ValueError('filename must be a string or file handle')
     if header:
         start_row = 1
     else:
         start_row = 0
-    for row, line in enumerate(input_file, start_row):
-        # Split each line using the supplied delimiter
-        data = splitter.split(line.strip())
-        try:
-            converted_time = float(data[0])
-        except (TypeError, ValueError) as exe:
-            six.raise_from(ValueError("Couldn't convert value {} using {} "
-                                      "found at {}:{:d}:\n\t{}".format(
-                                          data[0], float.__name__,
-                                          filename, row, line)), exe)
-        times.append(converted_time)
-
-        # cast values to a numpy array. time stamps with no values are cast
-        # to an empty array.
-        try:
-            converted_value = np.array(data[1:], dtype=dtype)
-        except (TypeError, ValueError) as exe:
-            six.raise_from(ValueError("Couldn't convert value {} using {} "
-                                      "found at {}:{:d}:\n\t{}".format(
-                                          data[1:], dtype.__name__,
-                                          filename, row, line)), exe)
-        values.append(converted_value)
-
-    # Close the file handle if we opened it
-    if own_fh:
-        input_file.close()
+    with _open(filename, mode='r') as input_file:
+        for row, line in enumerate(input_file, start_row):
+            # Split each line using the supplied delimiter
+            data = splitter.split(line.strip())
+            try:
+                converted_time = float(data[0])
+            except (TypeError, ValueError) as exe:
+                six.raise_from(ValueError("Couldn't convert value {} using {} "
+                                          "found at {}:{:d}:\n\t{}".format(
+                                              data[0], float.__name__,
+                                              filename, row, line)), exe)
+            times.append(converted_time)
+
+            # cast values to a numpy array. time stamps with no values are cast
+            # to an empty array.
+            try:
+                converted_value = np.array(data[1:], dtype=dtype)
+            except (TypeError, ValueError) as exe:
+                six.raise_from(ValueError("Couldn't convert value {} using {} "
+                                          "found at {}:{:d}:\n\t{}".format(
+                                              data[1:], dtype.__name__,
+                                              filename, row, line)), exe)
+            values.append(converted_value)
 
     return np.array(times), values
diff --git a/tests/test_input_output.py b/tests/test_input_output.py
index 5d512619..87074d48 100644
--- a/tests/test_input_output.py
+++ b/tests/test_input_output.py
@@ -11,7 +11,7 @@ def test_load_delimited():
 
     # Test for ValueError when a non-string or file handle is passed
     nose.tools.assert_raises(
-        ValueError, mir_eval.io.load_delimited, None, [int])
+        IOError, mir_eval.io.load_delimited, None, [int])
     # Test for a value error when the wrong number of columns is passed
     with tempfile.TemporaryFile('r+') as f:
         f.write('10 20')
@@ -119,7 +119,7 @@ def test_load_valued_intervals():
 def test_load_ragged_time_series():
     # Test for ValueError when a non-string or file handle is passed
     nose.tools.assert_raises(
-        ValueError, mir_eval.io.load_ragged_time_series, None, float,
+        IOError, mir_eval.io.load_ragged_time_series, None, float,
         header=False)
     # Test for a value error on conversion failure
     with tempfile.TemporaryFile('r+') as f: