refactored file-or-str opening. fixes #191 (#208)

Merged 1 commit on Jul 15, 2016
209 changes: 88 additions & 121 deletions mir_eval/io.py
@@ -2,6 +2,7 @@
Functions for loading in annotations from files in different formats.
"""

import contextlib
import numpy as np
import re
import warnings
@@ -12,6 +13,25 @@
from . import key


@contextlib.contextmanager
def _open(file_or_str, **kwargs):
'''Either open a file handle, or use an existing file-like object.

This will behave as the `open` function if `file_or_str` is a string.

    If `file_or_str` has a `read` attribute, it will be yielded as-is.

Otherwise, an `IOError` is raised.
'''
if hasattr(file_or_str, 'read'):
yield file_or_str
elif isinstance(file_or_str, six.string_types):
with open(file_or_str, **kwargs) as file_desc:
yield file_desc
else:
raise IOError('Invalid file-or-str object: {}'.format(file_or_str))

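For illustration, a minimal usage sketch of the new helper (not part of the diff; the file path and stream contents here are hypothetical). Strings are opened and closed for you, file-like objects are passed through untouched, and anything else raises IOError:

import six

# A path string behaves like the built-in open(), including cleanup on exit
with _open('annotations.txt', mode='r') as fh:
    first_line = fh.readline()

# Anything with a .read attribute is yielded as-is and is NOT closed by _open
buf = six.StringIO('0.0 1.0\n')
with _open(buf) as fh:
    first_line = fh.readline()

# Everything else fails fast when the context is entered:
# with _open(None): ...  ->  IOError: Invalid file-or-str object: None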

def load_delimited(filename, converters, delimiter=r'\s+'):
r"""Utility function for loading in data from an annotation file where columns
are delimited. The number of columns is inferred from the length of
@@ -49,51 +69,33 @@ def load_delimited(filename, converters, delimiter=r'\s+'):
# Create re object for splitting lines
splitter = re.compile(delimiter)

# Keep track of whether we create our own file handle
own_fh = False
# If the filename input is a string, need to open it
if isinstance(filename, six.string_types):
# Remember that we need to close it later
own_fh = True
# Open the file for reading
input_file = open(filename, 'r')
# If the provided object has a read attribute, we can use it as a file handle
elif hasattr(filename, 'read'):
input_file = filename
# Raise error otherwise
else:
raise ValueError('filename must be a string or file handle')

# Note: we do io manually here for two reasons.
# 1. The csv module has difficulties with unicode, which may lead
# to failures on certain annotation strings
#
# 2. numpy's text loader does not handle non-numeric data
#
for row, line in enumerate(input_file, 1):
# Split each line using the supplied delimiter
data = splitter.split(line.strip(), n_columns - 1)

# Throw a helpful error if we got an unexpected # of columns
if n_columns != len(data):
raise ValueError('Expected {} columns, got {} at '
'{}:{:d}:\n\t{}'.format(n_columns, len(data),
filename, row, line))

for value, column, converter in zip(data, columns, converters):
# Try converting the value, throw a helpful error on failure
try:
converted_value = converter(value)
except:
raise ValueError("Couldn't convert value {} using {} "
"found at {}:{:d}:\n\t{}".format(
value, converter.__name__, filename, row,
line))
column.append(converted_value)

# Close the file handle if we opened it
if own_fh:
input_file.close()
with _open(filename, mode='r') as input_file:
for row, line in enumerate(input_file, 1):
# Split each line using the supplied delimiter
data = splitter.split(line.strip(), n_columns - 1)

# Throw a helpful error if we got an unexpected # of columns
if n_columns != len(data):
raise ValueError('Expected {} columns, got {} at '
'{}:{:d}:\n\t{}'.format(n_columns, len(data),
filename, row, line))

for value, column, converter in zip(data, columns, converters):
# Try converting the value, throw a helpful error on failure
try:
converted_value = converter(value)
except:
raise ValueError("Couldn't convert value {} using {} "
"found at {}:{:d}:\n\t{}".format(
value, converter.__name__, filename,
row, line))
column.append(converted_value)

# Sane output
if n_columns == 1:
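As a quick usage sketch of the refactored loader (hypothetical contents; each converter is a per-column callable), a two-column annotation can now come from any file-like object:

import six
data = six.StringIO('0.0 silence\n1.0 speech\n')
times, labels = load_delimited(data, [float, str])
# times  -> [0.0, 1.0]
# labels -> ['silence', 'speech']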
@@ -313,54 +315,36 @@ def load_patterns(filename):

"""

# Keep track of whether we create our own file handle
own_fh = False
# If the filename input is a string, need to open it
if isinstance(filename, six.string_types):
# Remember that we need to close it later
own_fh = True
# Open the file for reading
input_file = open(filename, 'r')
# If the provided object has a read attribute, we can use it as a file handle
elif hasattr(filename, 'read'):
input_file = filename
# Raise error otherwise
else:
raise ValueError('filename must be a string or file handle')

# List with all the patterns
pattern_list = []
# Current pattern, which will contain all occs
pattern = []
# Current occurrence, containing (onset, midi)
occurrence = []
for line in input_file.readlines():
if "pattern" in line:
if occurrence != []:
pattern.append(occurrence)
if pattern != []:
pattern_list.append(pattern)
occurrence = []
pattern = []
continue
if "occurrence" in line:
if occurrence != []:
pattern.append(occurrence)
occurrence = []
continue
string_values = line.split(",")
onset_midi = (float(string_values[0]), float(string_values[1]))
occurrence.append(onset_midi)

# Add last occurrence and pattern to pattern_list
if occurrence != []:
pattern.append(occurrence)
if pattern != []:
pattern_list.append(pattern)

# If we opened an input file, we need to close it
if own_fh:
input_file.close()
with _open(filename, mode='r') as input_file:
for line in input_file.readlines():
if "pattern" in line:
if occurrence != []:
pattern.append(occurrence)
if pattern != []:
pattern_list.append(pattern)
occurrence = []
pattern = []
continue
if "occurrence" in line:
if occurrence != []:
pattern.append(occurrence)
occurrence = []
continue
string_values = line.split(",")
onset_midi = (float(string_values[0]), float(string_values[1]))
occurrence.append(onset_midi)

# Add last occurrence and pattern to pattern_list
if occurrence != []:
pattern.append(occurrence)
if pattern != []:
pattern_list.append(pattern)

return pattern_list
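As a usage sketch, the parser above treats any line containing "pattern" as the start of a new pattern, any line containing "occurrence" as the start of a new occurrence, and every other line as a comma-separated (onset, midi) pair. With hypothetical input:

import six
data = six.StringIO('pattern1\n'
                    'occurrence1\n'
                    '0.5, 60.0\n'
                    '1.0, 62.0\n'
                    'occurrence2\n'
                    '2.5, 60.0\n'
                    'pattern2\n'
                    'occurrence1\n'
                    '4.0, 67.0\n')
patterns = load_patterns(data)
# [[[(0.5, 60.0), (1.0, 62.0)], [(2.5, 60.0)]],  # pattern 1: two occurrences
#  [[(4.0, 67.0)]]]                              # pattern 2: one occurrence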

@@ -523,49 +507,32 @@ def load_ragged_time_series(filename, dtype=float, delimiter=r'\s+',
# Create re object for splitting lines
splitter = re.compile(delimiter)

# Keep track of whether we create our own file handle
own_fh = False
# If the filename input is a string, need to open it
if isinstance(filename, six.string_types):
# Remember that we need to close it later
own_fh = True
# Open the file for reading
input_file = open(filename, 'r')
# If the provided object has a read attribute, we can use it as a file handle
elif hasattr(filename, 'read'):
input_file = filename
# Raise error otherwise
else:
raise ValueError('filename must be a string or file handle')
if header:
start_row = 1
else:
start_row = 0
for row, line in enumerate(input_file, start_row):
# Split each line using the supplied delimiter
data = splitter.split(line.strip())
try:
converted_time = float(data[0])
except (TypeError, ValueError) as exe:
six.raise_from(ValueError("Couldn't convert value {} using {} "
"found at {}:{:d}:\n\t{}".format(
data[0], float.__name__,
filename, row, line)), exe)
times.append(converted_time)

# cast values to a numpy array. time stamps with no values are cast
# to an empty array.
try:
converted_value = np.array(data[1:], dtype=dtype)
except (TypeError, ValueError) as exe:
six.raise_from(ValueError("Couldn't convert value {} using {} "
"found at {}:{:d}:\n\t{}".format(
data[1:], dtype.__name__,
filename, row, line)), exe)
values.append(converted_value)

# Close the file handle if we opened it
if own_fh:
input_file.close()
with _open(filename, mode='r') as input_file:
for row, line in enumerate(input_file, start_row):
# Split each line using the supplied delimiter
data = splitter.split(line.strip())
try:
converted_time = float(data[0])
except (TypeError, ValueError) as exe:
six.raise_from(ValueError("Couldn't convert value {} using {} "
"found at {}:{:d}:\n\t{}".format(
data[0], float.__name__,
filename, row, line)), exe)
times.append(converted_time)

# cast values to a numpy array. time stamps with no values are cast
# to an empty array.
try:
converted_value = np.array(data[1:], dtype=dtype)
except (TypeError, ValueError) as exe:
six.raise_from(ValueError("Couldn't convert value {} using {} "
"found at {}:{:d}:\n\t{}".format(
data[1:], dtype.__name__,
filename, row, line)), exe)
values.append(converted_value)

return np.array(times), values
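A usage sketch for the ragged loader (hypothetical contents; no header row is assumed): the first column on each line is a time stamp, and the remainder, which may vary in length or be empty, becomes a numpy array:

import six
data = six.StringIO('0.0 440.0\n'
                    '1.0 440.0 220.0\n'
                    '2.0\n')
times, values = load_ragged_time_series(data, dtype=float, header=False)
# times  -> array([0., 1., 2.])
# values -> [array([440.]), array([440., 220.]), array([], dtype=float64)]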
4 changes: 2 additions & 2 deletions tests/test_input_output.py
@@ -11,7 +11,7 @@
def test_load_delimited():
# Test for IOError when a non-string or file handle is passed
nose.tools.assert_raises(
ValueError, mir_eval.io.load_delimited, None, [int])
IOError, mir_eval.io.load_delimited, None, [int])
# Test for a value error when the wrong number of columns is passed
with tempfile.TemporaryFile('r+') as f:
f.write('10 20')
@@ -119,7 +119,7 @@ def test_load_valued_intervals():
def test_load_ragged_time_series():
# Test for IOError when a non-string or file handle is passed
nose.tools.assert_raises(
ValueError, mir_eval.io.load_ragged_time_series, None, float,
IOError, mir_eval.io.load_ragged_time_series, None, float,
header=False)
# Test for a value error on conversion failure
with tempfile.TemporaryFile('r+') as f:
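The test updates simply track the new error type: invalid input now surfaces the IOError raised by _open rather than a ValueError. A minimal check of the new contract (written pytest-style here as an assumption; the project's own tests use nose):

import pytest
import mir_eval

def test_invalid_input_raises_ioerror():
    # None is neither a string nor file-like, so _open rejects it
    with pytest.raises(IOError):
        mir_eval.io.load_delimited(None, [int])
    with pytest.raises(IOError):
        mir_eval.io.load_ragged_time_series(None, float, header=False)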