Skip to content

Commit

Permalink
Merge pull request #56 from joezuntz/fix_wrong_array_get
Browse files Browse the repository at this point in the history
Fix hard crash when using wrong array types with put/get
  • Loading branch information
marcpaterno authored Nov 7, 2022
2 parents 63e87a3 + f281cd0 commit fe9b518
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 19 deletions.
53 changes: 35 additions & 18 deletions cosmosis/datablock/c_datablock.cc
Original file line number Diff line number Diff line change
Expand Up @@ -564,10 +564,16 @@ extern "C"
if (sz == nullptr) return DBS_SIZE_NULL;

auto p = static_cast<DataBlock *>(s);
vector<int> const& r = p->view<vector<int>>(section, name);
*sz = r.size();
if (r.size() > static_cast<size_t>(maxsize)) return DBS_SIZE_INSUFFICIENT;
std::copy(r.cbegin(), r.cend(), val);
try {
vector<int> const& r = p->view<vector<int>>(section, name);
*sz = r.size();
if (r.size() > static_cast<size_t>(maxsize)) return DBS_SIZE_INSUFFICIENT;
std::copy(r.cbegin(), r.cend(), val);
}
catch (DataBlock::BadDataBlockAccess const&) { return DBS_SECTION_NOT_FOUND; }
catch (Section::BadSectionAccess const&) { return DBS_NAME_NOT_FOUND; }
catch (Entry::BadEntry const&) { return DBS_WRONG_VALUE_TYPE; }
catch (...) { return DBS_LOGIC_ERROR; }
// If we are asked to clear out the remainder of the input buffer,
// the following line should be used.
// std::fill(val + *sz, val+maxsize, 0);
Expand All @@ -590,10 +596,16 @@ extern "C"
if (sz == nullptr) return DBS_SIZE_NULL;

auto p = static_cast<DataBlock *>(s);
vector<double> const& r = p->view<vector<double>>(section, name);
*sz = r.size();
if (r.size() > static_cast<size_t>(maxsize)) return DBS_SIZE_INSUFFICIENT;
std::copy(r.cbegin(), r.cend(), val);
try {
vector<double> const& r = p->view<vector<double>>(section, name);
*sz = r.size();
if (r.size() > static_cast<size_t>(maxsize)) return DBS_SIZE_INSUFFICIENT;
std::copy(r.cbegin(), r.cend(), val);
}
catch (DataBlock::BadDataBlockAccess const&) { return DBS_SECTION_NOT_FOUND; }
catch (Section::BadSectionAccess const&) { return DBS_NAME_NOT_FOUND; }
catch (Entry::BadEntry const&) { return DBS_WRONG_VALUE_TYPE; }
catch (...) { return DBS_LOGIC_ERROR; }
// If we are asked to clear out the remainder of the input buffer,
// the following line should be used.
// std::fill(val + *sz, val+maxsize, 0);
Expand All @@ -616,14 +628,21 @@ extern "C"
if (sz == nullptr) return DBS_SIZE_NULL;

auto p = static_cast<DataBlock *>(s);
vector<complex_t> const& r = p->view<vector<complex_t>>(section, name);
*sz = r.size();
if (r.size() > static_cast<size_t>(maxsize)) return DBS_SIZE_INSUFFICIENT;
//std::copy(r.cbegin(), r.cend(), val);
for (size_t i = 0, n = r.size(); i != n; ++i)
{
val[i] = from_complex(r[i]);
}
try{
vector<complex_t> const& r = p->view<vector<complex_t>>(section, name);
*sz = r.size();
if (r.size() > static_cast<size_t>(maxsize)) return DBS_SIZE_INSUFFICIENT;
//std::copy(r.cbegin(), r.cend(), val);
for (size_t i = 0, n = r.size(); i != n; ++i)
{
val[i] = from_complex(r[i]);
}
}
catch (DataBlock::BadDataBlockAccess const&) { return DBS_SECTION_NOT_FOUND; }
catch (Section::BadSectionAccess const&) { return DBS_NAME_NOT_FOUND; }
catch (Entry::BadEntry const&) { return DBS_WRONG_VALUE_TYPE; }
catch (...) { return DBS_LOGIC_ERROR; }


// If we are asked to clear out the remainder of the input buffer,
// the following line should be used.
Expand Down Expand Up @@ -652,8 +671,6 @@ extern "C"
for (int i=0; i<*sz; i++){
val[i] = strdup(r[i].c_str());
}

*sz = r.size();
}
catch (DataBlock::BadDataBlockAccess const&) { return DBS_SECTION_NOT_FOUND; }
catch (Section::BadSectionAccess const&) { return DBS_NAME_NOT_FOUND; }
Expand Down
31 changes: 30 additions & 1 deletion cosmosis/test/test_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,37 @@ def test_keys():
assert k in b


def test_wrong_array_type():
puts = {
int: "put_int_array_1d",
float: "put_double_array_1d",
str: "put_string_array_1d",
}
gets = {
int: "get_int_array_1d",
float: "get_double_array_1d",
str: "get_string_array_1d",
}
dtypes = list(puts.keys())

for d1 in dtypes[:]:
for d2 in dtypes[:]:
if d1 is d2:
continue

b = DataBlock()
section = 'section'
key = 'key'
put = getattr(b, puts[d1])
get = getattr(b, gets[d2])

value = np.array([1, 2, 3], dtype=d1)
put(section, key, value)
with pytest.raises(errors.BlockWrongValueType):
get(section, key)


if __name__ == '__main__':
# test_string_array()
test_string_array_save()
# test_string_array_save()
test_wrong_array_type()

0 comments on commit fe9b518

Please sign in to comment.