Skip to content

Commit

Permalink
add explicit interpolation: interp(); fix tests to use new API
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasha committed Aug 31, 2017
1 parent 3faae02 commit b41cd38
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
8 changes: 6 additions & 2 deletions audio_analysis/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import bokeh.plotting
import bokeh.models
import pandas
import scipy.interpolate


def new_data(x, xlabel, y, ylabel):
Expand All @@ -24,7 +25,7 @@ def plot(self, fig=None, legend=None, color=None):
sizing_mode='stretch_both')
# increase the limit for sci. notation on x-axis
fig.xaxis.formatter = bokeh.models.BasicTickFormatter(power_limit_high=6)
if self._x2_factor is not None:
if hasattr(self, '_x2_factor'):
x2_ax = bokeh.models.LinearAxis(axis_label=self._x2_label)
js_code = "return Math.round(tick * %f * 100) / 100" % self._x2_factor
x2_ax.formatter = bokeh.models.FuncTickFormatter(code=js_code)
Expand All @@ -33,10 +34,13 @@ def plot(self, fig=None, legend=None, color=None):
return fig

def add_x_axis(self, factor, label):
assert self._x2_factor is None, 'only 2 x-axis possible'
assert not hasattr(self, '_x2_factor'), 'only 2 x-axis possible'
self._x2_factor = factor
self._x2_label = label

def interp(self, x, kind='linear'):
return scipy.interpolate.interp1d(self.index, self, kind)(x)


def _dbfs(sq):
return 10 * np.log10(sq * 2) # +3dB, dBFS sine scaling
Expand Down
6 changes: 3 additions & 3 deletions tests/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_spectrum(file):
ch = 0
pcm = audio.PCMArray(os.path.join(here, file))[ch]
s = tools.spectrum(pcm)
assert abs(s[hz] - dbfs) < TOLERANCE_DB
assert abs(s.interp(hz) - dbfs) < TOLERANCE_DB


@pytest.mark.parametrize('file', FILES.keys())
Expand All @@ -44,6 +44,6 @@ def test_power(file):
ch = 0
pcm = audio.PCMArray(os.path.join(here, file))[ch]
p = tools.power(pcm, window=1023)
assert all(abs(p.y - dbfs) < TOLERANCE_DB + .1)
assert abs(p.y.mean() - dbfs) < TOLERANCE_DB
assert all(abs(p - dbfs) < TOLERANCE_DB + .1)
assert abs(p.mean() - dbfs) < TOLERANCE_DB

0 comments on commit b41cd38

Please sign in to comment.