Skip to content

Commit

Permalink
test pd copy assign
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Feb 20, 2025
1 parent de69646 commit 6d1cd86
Showing 1 changed file with 143 additions and 1 deletion.
144 changes: 143 additions & 1 deletion tests/test_paramdict.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,147 @@ static int test_paramdict_5()
return 0;
}

static int compare_paramdict(const ncnn::ParamDict& pd, const ncnn::ParamDict& pd0)
{
for (int id = 0; ; id++)
{
const int type0 = pd0.type(id);
if (type0 == 0)
{
break;
}
else if (type0 == 2)
{
const int i0 = pd0.get(id, 0);
int i = pd.get(id, 0);
if (i != i0)
{
fprintf(stderr, "compare_paramdict int failed %d != %d\n", i, i0);
return -1;
}
}
else if (type0 == 3)
{
const float f0 = pd0.get(id, 0.f);
int f = pd.get(id, 0.f);
if (f != f0)
{
fprintf(stderr, "compare_paramdict float failed %f != %f\n", f, f0);
return -1;
}
}
else if (type0 == 5)
{
const ncnn::Mat ai0 = pd0.get(id, ncnn::Mat());
ncnn::Mat ai = pd.get(id, ncnn::Mat());
if (ai.w != ai0.w)
{
fprintf(stderr, "compare_paramdict int array size failed %d != %d\n", ai.w, ai0.w);
return -1;
}
for (int q = 0; q < ai0.w; q++)
{
int i0 = ((const int*)ai0)[q];
int i = ((const int*)ai)[q];
if (i != i0)
{
fprintf(stderr, "compare_paramdict int array element %d failed %d != %d\n", q, i, i0);
return -1;
}
}
}
else if (type0 == 6)
{
const ncnn::Mat af0 = pd0.get(id, ncnn::Mat());
ncnn::Mat af = pd.get(id, ncnn::Mat());
if (af.w != af0.w)
{
fprintf(stderr, "compare_paramdict float array size failed %d != %d\n", af.w, af0.w);
return -1;
}
for (int q = 0; q < af0.w; q++)
{
float f0 = af0[q];
float f = af[q];
if (f != f0)
{
fprintf(stderr, "compare_paramdict float array element %d failed %f != %f\n", q, f, f0);
return -1;
}
}
}
else if (type0 == 7)
{
const std::string s0 = pd0.get(id, "");
std::string s = pd.get(id, "");
if (s != s0)
{
fprintf(stderr, "compare_paramdict string failed %s != %s\n", s.c_str(), s0.c_str());
return -1;
}
}
else
{
fprintf(stderr, "unexpected paramdict type %d\n", type0);
return -1;
}
}

return 0;
}

static int test_paramdict_6()
{
const int i0 = 11;
const float f0 = -2.2f;
const std::string s0 = "qwqwqwq";
ncnn::Mat ai0(1);
{
int* p = ai0;
p[0] = 233;
}

ncnn::Mat af0(4);
{
float* p = af0;
p[0] = 2.33f;
p[1] = -0.2f;
p[2] = 0.f;
p[3] = 9494.f;
}

ncnn::ParamDict pd0;
pd0.set(1, i0);
pd0.set(2, ai0);
pd0.set(3, f0);
pd0.set(4, af0);
pd0.set(5, s0);

// copy
{
ncnn::ParamDict pd(pd0);

int ret = compare_paramdict(pd, pd0);
if (ret != 0)
{
fprintf(stderr, "paramdict copy failed\n");
return -1;
}
}

// assign
{
ncnn::ParamDict pd = pd0;

int ret = compare_paramdict(pd, pd0);
if (ret != 0)
{
fprintf(stderr, "paramdict assign failed\n");
return -1;
}
}
}

int main()
{
return 0
Expand All @@ -505,5 +646,6 @@ int main()
|| test_paramdict_2()
|| test_paramdict_3()
|| test_paramdict_4()
|| test_paramdict_5();
|| test_paramdict_5()
|| test_paramdict_6();
}

0 comments on commit 6d1cd86

Please sign in to comment.