Skip to content

Commit

Permalink
refactored file-or-str opening. fixes #191
Browse files Browse the repository at this point in the history
hobgoblins.
  • Loading branch information
bmcfee committed Jul 15, 2016
1 parent 5887679 commit cb786bb
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 123 deletions.
209 changes: 88 additions & 121 deletions mir_eval/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Functions for loading in annotations from files in different formats.
"""

import contextlib
import numpy as np
import re
import warnings
Expand All @@ -12,6 +13,25 @@
from . import key


@contextlib.contextmanager
def _open(file_or_str, **kwargs):
    """Either open a file handle, or use an existing file-like object.

    This behaves like the builtin `open` if `file_or_str` is a string or
    a path-like object (e.g., `pathlib.Path`).  If `file_or_str` has a
    `read` attribute, `file_or_str` itself is yielded.

    Parameters
    ----------
    file_or_str : str, os.PathLike, or file-like
        Path of the file to open, or an object exposing a `read` attribute.
    **kwargs
        Additional keyword arguments forwarded to `open`.  These are
        ignored when `file_or_str` is already a file-like object.

    Yields
    ------
    file_desc : file-like
        `file_or_str` itself if it has a `read` attribute; otherwise the
        handle returned by `open`, which is closed when the context exits.

    Raises
    ------
    IOError
        If `file_or_str` is neither file-like, a string, nor path-like.
    """
    # Imported locally so this helper stays self-contained.
    import os

    # `os.PathLike` only exists on Python 3.6+.  Falling back to an empty
    # tuple makes the isinstance check below always False on older versions,
    # where the builtin `open` cannot handle path objects anyway.
    path_like = getattr(os, 'PathLike', ())

    if hasattr(file_or_str, 'read'):
        # The caller owns this object, so it is deliberately not closed here.
        yield file_or_str
    elif (isinstance(file_or_str, path_like) or
          isinstance(file_or_str, six.string_types)):
        # We created this handle, so the `with` block guarantees it is
        # closed, even if the caller's body raises.
        with open(file_or_str, **kwargs) as file_desc:
            yield file_desc
    else:
        raise IOError('Invalid file-or-str object: {}'.format(file_or_str))


def load_delimited(filename, converters, delimiter=r'\s+'):
r"""Utility function for loading in data from an annotation file where columns
are delimited. The number of columns is inferred from the length of
Expand Down Expand Up @@ -49,51 +69,33 @@ def load_delimited(filename, converters, delimiter=r'\s+'):
# Create re object for splitting lines
splitter = re.compile(delimiter)

# Keep track of whether we create our own file handle
own_fh = False
# If the filename input is a string, need to open it
if isinstance(filename, six.string_types):
# Remember that we need to close it later
own_fh = True
# Open the file for reading
input_file = open(filename, 'r')
# If the provided has a read attribute, we can use it as a file handle
elif hasattr(filename, 'read'):
input_file = filename
# Raise error otherwise
else:
raise ValueError('filename must be a string or file handle')

# Note: we do io manually here for two reasons.
# 1. The csv module has difficulties with unicode, which may lead
# to failures on certain annotation strings
#
# 2. numpy's text loader does not handle non-numeric data
#
for row, line in enumerate(input_file, 1):
# Split each line using the supplied delimiter
data = splitter.split(line.strip(), n_columns - 1)

# Throw a helpful error if we got an unexpected # of columns
if n_columns != len(data):
raise ValueError('Expected {} columns, got {} at '
'{}:{:d}:\n\t{}'.format(n_columns, len(data),
filename, row, line))

for value, column, converter in zip(data, columns, converters):
# Try converting the value, throw a helpful error on failure
try:
converted_value = converter(value)
except:
raise ValueError("Couldn't convert value {} using {} "
"found at {}:{:d}:\n\t{}".format(
value, converter.__name__, filename, row,
line))
column.append(converted_value)

# Close the file handle if we opened it
if own_fh:
input_file.close()
with _open(filename, mode='r') as input_file:
for row, line in enumerate(input_file, 1):
# Split each line using the supplied delimiter
data = splitter.split(line.strip(), n_columns - 1)

# Throw a helpful error if we got an unexpected # of columns
if n_columns != len(data):
raise ValueError('Expected {} columns, got {} at '
'{}:{:d}:\n\t{}'.format(n_columns, len(data),
filename, row, line))

for value, column, converter in zip(data, columns, converters):
# Try converting the value, throw a helpful error on failure
try:
converted_value = converter(value)
except:
raise ValueError("Couldn't convert value {} using {} "
"found at {}:{:d}:\n\t{}".format(
value, converter.__name__, filename,
row, line))
column.append(converted_value)

# Sane output
if n_columns == 1:
Expand Down Expand Up @@ -313,54 +315,36 @@ def load_patterns(filename):
"""

# Keep track of whether we create our own file handle
own_fh = False
# If the filename input is a string, need to open it
if isinstance(filename, six.string_types):
# Remember that we need to close it later
own_fh = True
# Open the file for reading
input_file = open(filename, 'r')
# If the provided has a read attribute, we can use it as a file handle
elif hasattr(filename, 'read'):
input_file = filename
# Raise error otherwise
else:
raise ValueError('filename must be a string or file handle')

# List with all the patterns
pattern_list = []
# Current pattern, which will contain all occs
pattern = []
# Current occurrence, containing (onset, midi)
occurrence = []
for line in input_file.readlines():
if "pattern" in line:
if occurrence != []:
pattern.append(occurrence)
if pattern != []:
pattern_list.append(pattern)
occurrence = []
pattern = []
continue
if "occurrence" in line:
if occurrence != []:
pattern.append(occurrence)
occurrence = []
continue
string_values = line.split(",")
onset_midi = (float(string_values[0]), float(string_values[1]))
occurrence.append(onset_midi)

# Add last occurrence and pattern to pattern_list
if occurrence != []:
pattern.append(occurrence)
if pattern != []:
pattern_list.append(pattern)

# If we opened an input file, we need to close it
if own_fh:
input_file.close()
with _open(filename, mode='r') as input_file:
for line in input_file.readlines():
if "pattern" in line:
if occurrence != []:
pattern.append(occurrence)
if pattern != []:
pattern_list.append(pattern)
occurrence = []
pattern = []
continue
if "occurrence" in line:
if occurrence != []:
pattern.append(occurrence)
occurrence = []
continue
string_values = line.split(",")
onset_midi = (float(string_values[0]), float(string_values[1]))
occurrence.append(onset_midi)

# Add last occurrence and pattern to pattern_list
if occurrence != []:
pattern.append(occurrence)
if pattern != []:
pattern_list.append(pattern)

return pattern_list

Expand Down Expand Up @@ -523,49 +507,32 @@ def load_ragged_time_series(filename, dtype=float, delimiter=r'\s+',
# Create re object for splitting lines
splitter = re.compile(delimiter)

# Keep track of whether we create our own file handle
own_fh = False
# If the filename input is a string, need to open it
if isinstance(filename, six.string_types):
# Remember that we need to close it later
own_fh = True
# Open the file for reading
input_file = open(filename, 'r')
# If the provided has a read attribute, we can use it as a file handle
elif hasattr(filename, 'read'):
input_file = filename
# Raise error otherwise
else:
raise ValueError('filename must be a string or file handle')
if header:
start_row = 1
else:
start_row = 0
for row, line in enumerate(input_file, start_row):
# Split each line using the supplied delimiter
data = splitter.split(line.strip())
try:
converted_time = float(data[0])
except (TypeError, ValueError) as exe:
six.raise_from(ValueError("Couldn't convert value {} using {} "
"found at {}:{:d}:\n\t{}".format(
data[0], float.__name__,
filename, row, line)), exe)
times.append(converted_time)

# cast values to a numpy array. time stamps with no values are cast
# to an empty array.
try:
converted_value = np.array(data[1:], dtype=dtype)
except (TypeError, ValueError) as exe:
six.raise_from(ValueError("Couldn't convert value {} using {} "
"found at {}:{:d}:\n\t{}".format(
data[1:], dtype.__name__,
filename, row, line)), exe)
values.append(converted_value)

# Close the file handle if we opened it
if own_fh:
input_file.close()
with _open(filename, mode='r') as input_file:
for row, line in enumerate(input_file, start_row):
# Split each line using the supplied delimiter
data = splitter.split(line.strip())
try:
converted_time = float(data[0])
except (TypeError, ValueError) as exe:
six.raise_from(ValueError("Couldn't convert value {} using {} "
"found at {}:{:d}:\n\t{}".format(
data[0], float.__name__,
filename, row, line)), exe)
times.append(converted_time)

# cast values to a numpy array. time stamps with no values are cast
# to an empty array.
try:
converted_value = np.array(data[1:], dtype=dtype)
except (TypeError, ValueError) as exe:
six.raise_from(ValueError("Couldn't convert value {} using {} "
"found at {}:{:d}:\n\t{}".format(
data[1:], dtype.__name__,
filename, row, line)), exe)
values.append(converted_value)

return np.array(times), values
4 changes: 2 additions & 2 deletions tests/test_input_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
def test_load_delimited():
# Test for IOError when the input is neither a string nor a file handle
nose.tools.assert_raises(
ValueError, mir_eval.io.load_delimited, None, [int])
IOError, mir_eval.io.load_delimited, None, [int])
# Test for a value error when the wrong number of columns is passed
with tempfile.TemporaryFile('r+') as f:
f.write('10 20')
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_load_valued_intervals():
def test_load_ragged_time_series():
# Test for IOError when the input is neither a string nor a file handle
nose.tools.assert_raises(
ValueError, mir_eval.io.load_ragged_time_series, None, float,
IOError, mir_eval.io.load_ragged_time_series, None, float,
header=False)
# Test for a value error on conversion failure
with tempfile.TemporaryFile('r+') as f:
Expand Down

0 comments on commit cb786bb

Please sign in to comment.