Skip to content

Commit

Permalink
Merge pull request #3795 from lscheinkman/1380
Browse files Browse the repository at this point in the history
numenta/nupic.core-legacy#1380: Fix SP tests with correct dtype values
  • Loading branch information
lscheinkman authored Jan 16, 2018
2 parents 17db904 + 01a2715 commit af7735f
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 6 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ prettytable==0.7.2

# When updating nupic.bindings, also update any shared dependencies to keep
# versions in sync.
nupic.bindings==1.0.0
nupic.bindings==1.0.3
numpy==1.12.1
6 changes: 3 additions & 3 deletions tests/unit/nupic/algorithms/sp_overlap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ def frequency(self,
maxval=maxVal, periodic=False, forced=True) # forced: it's strongly recommended to use w>=21, in the example we force skip the check for readibility
for y in xrange(numColors):
temp = enc.encode(rnd.random()*maxVal)
colors.append(numpy.array(temp, dtype=realDType))
colors.append(numpy.array(temp, dtype=numpy.uint32))
else:
for y in xrange(numColors):
sdr = numpy.zeros(n, dtype=realDType)
sdr = numpy.zeros(n, dtype=numpy.uint32)
# Randomly setting w out of n bits to 1
sdr[rnd.sample(xrange(n), w)] = 1
colors.append(sdr)
Expand All @@ -144,7 +144,7 @@ def frequency(self,
for i in xrange(numColors):
# TODO: See https://github.com/numenta/nupic/issues/2072
spInput = colors[i]
onCells = numpy.zeros(columnDimensions)
onCells = numpy.zeros(columnDimensions, dtype=numpy.uint32)
spImpl.compute(spInput, True, onCells)
spOutput.append(onCells.tolist())
activeCoincIndices = set(onCells.nonzero()[0])
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/nupic/algorithms/spatial_pooler_cpp_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,38 @@ def testUpdateDutyCycles(self):
self.assertEqual(list(resultOverlapArr2), list(trueOverlapArr2))


def testComputeParametersValidation(self):
sp = SpatialPooler(inputDimensions=[5], columnDimensions=[5])
inputGood = np.ones(5, dtype=uintDType)
outGood = np.zeros(5, dtype=uintDType)
inputBad = np.ones(5, dtype=realDType)
inputBad2D = np.ones((5, 5), dtype=realDType)
outBad = np.zeros(5, dtype=realDType)
outBad2D = np.zeros((5, 5), dtype=realDType)

# Validate good parameters
sp.compute(inputGood, False, outGood)

# Validate bad parameters
with self.assertRaises(RuntimeError):
sp.compute(inputBad, False, outBad)

# Validate bad input
with self.assertRaises(RuntimeError):
sp.compute(inputBad, False, outGood)

# Validate bad 2d input
with self.assertRaises(RuntimeError):
sp.compute(inputBad2D, False, outGood)

# Validate bad output
with self.assertRaises(RuntimeError):
sp.compute(inputGood, False, outBad)

# Validate bad 2d output
with self.assertRaises(RuntimeError):
sp.compute(inputGood, False, outBad2D)


if __name__ == "__main__":
unittest.main()
4 changes: 2 additions & 2 deletions tests/unit/nupic/algorithms/spatial_pooler_py_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def setUp(self):

def testCompute(self):
# Check that there are no errors in call to compute
inputVector = numpy.ones(5)
activeArray = numpy.zeros(5)
inputVector = numpy.ones(5, dtype=uintType)
activeArray = numpy.zeros(5, dtype=uintType)
self.sp.compute(inputVector, True, activeArray)


Expand Down

0 comments on commit af7735f

Please sign in to comment.