Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

paramdict value string type, natural array representation #5915

Merged
merged 11 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/developer-guide/param-and-model-file-structure.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ the meaning of existing param key index can be looked up at [operation-param-wei
* integer array value : [array size],int,int,...,int
* float array value : [array size],float,float,...,float

In modern ncnn param file

* array could be represented as `3=2.0,3.0` that is much more human friendly
* string typed value: `4=hello` and the string is no longer than 255

## net.bin
```
+---------+---------+---------+---------+---------+---------+
Expand Down
193 changes: 183 additions & 10 deletions src/paramdict.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ class ParamDictPrivate
// 4 = array of int/float
// 5 = array of int
// 6 = array of float
// 7 = string
int type;
union
{
int i;
float f;
};
Mat v;
std::string s;
} params[NCNN_MAX_PARAM_COUNT];
};

Expand All @@ -70,6 +72,10 @@ ParamDict::ParamDict(const ParamDict& rhs)
{
d->params[i].i = rhs.d->params[i].i;
}
else if (type == 7)
{
d->params[i].s = rhs.d->params[i].s;
}
else // if (type == 4 || type == 5 || type == 6)
{
d->params[i].v = rhs.d->params[i].v;
Expand All @@ -90,6 +96,10 @@ ParamDict& ParamDict::operator=(const ParamDict& rhs)
{
d->params[i].i = rhs.d->params[i].i;
}
else if (type == 7)
{
d->params[i].s = rhs.d->params[i].s;
}
else // if (type == 4 || type == 5 || type == 6)
{
d->params[i].v = rhs.d->params[i].v;
Expand Down Expand Up @@ -120,6 +130,11 @@ Mat ParamDict::get(int id, const Mat& def) const
return d->params[id].type ? d->params[id].v : def;
}

std::string ParamDict::get(int id, const std::string& def) const
{
return d->params[id].type ? d->params[id].s : def;
}

void ParamDict::set(int id, int i)
{
d->params[id].type = 2;
Expand All @@ -138,12 +153,20 @@ void ParamDict::set(int id, const Mat& v)
d->params[id].v = v;
}

void ParamDict::set(int id, const std::string& s)
{
d->params[id].type = 7;
d->params[id].s = s;
}

void ParamDict::clear()
{
for (int i = 0; i < NCNN_MAX_PARAM_COUNT; i++)
{
d->params[i].type = 0;
d->params[i].i = 0;
d->params[i].v = Mat();
d->params[i].s.clear();
}
}

Expand All @@ -163,6 +186,11 @@ static bool vstr_is_float(const char vstr[16])
return false;
}

static bool vstr_is_string(const char vstr[16])
{
return isalpha(vstr[0]);
}

static float vstr_to_float(const char vstr[16])
{
double v = 0.0;
Expand Down Expand Up @@ -247,7 +275,8 @@ int ParamDict::load_param(const DataReader& dr)
{
clear();

// 0=100 1=1.250000 -23303=5,0.1,0.2,0.4,0.8,1.0
// 0=100 1=1.250000 -23303=5,0.1,0.2,0.4,0.8,1.0
// 3=0.1,0.2,0.4,0.8,1.0

// parse each key=value pair
int id = 0;
Expand All @@ -267,6 +296,7 @@ int ParamDict::load_param(const DataReader& dr)

if (is_array)
{
// old style array
int len = 0;
int nscan = dr.scan("%d", &len);
if (nscan != 1)
Expand Down Expand Up @@ -307,19 +337,120 @@ int ParamDict::load_param(const DataReader& dr)

d->params[id].type = is_float ? 6 : 5;
}

continue;
}
else

char vstr[16];
char comma[4];
int nscan = dr.scan("%15[^,\n ]", vstr);
if (nscan != 1)
{
char vstr[16];
int nscan = dr.scan("%15s", vstr);
if (nscan != 1)
NCNN_LOGE("ParamDict read value failed");
return -1;
}

bool is_string = vstr_is_string(vstr);
if (is_string)
{
// scan the remaining string
char vstr2[256];
vstr2[241] = '\0'; // max 255 = 15 + 240
nscan = dr.scan("%255[^\n ]", vstr2);
if (nscan == 1)
{
NCNN_LOGE("ParamDict read value failed");
return -1;
if (vstr2[241] != '\0')
{
NCNN_LOGE("string too long (id=%d)", id);
return -1;
}

d->params[id].s = std::string(vstr) + vstr2;
}
else
{
d->params[id].s = std::string(vstr);
}

d->params[id].type = 7;

continue;
}

bool is_float = vstr_is_float(vstr);

nscan = dr.scan("%1[,]", comma);
is_array = nscan == 1;

if (is_array)
{
std::vector<float> af;
std::vector<int> ai;

if (is_float)
{
af.push_back(vstr_to_float(vstr));
}
else
{
int v = 0;
nscan = sscanf(vstr, "%d", &v);
if (nscan != 1)
{
NCNN_LOGE("ParamDict parse value failed");
return -1;
}

ai.push_back(v);
}

while (1)
{
nscan = dr.scan("%15[^,\n ]", vstr);
if (nscan != 1)
{
break;
}

if (is_float)
{
af.push_back(vstr_to_float(vstr));
}
else
{
int v = 0;
nscan = sscanf(vstr, "%d", &v);
if (nscan != 1)
{
NCNN_LOGE("ParamDict parse value failed");
return -1;
}

ai.push_back(v);
}

nscan = dr.scan("%1[,]", comma);
if (nscan != 1)
{
break;
}
}

bool is_float = vstr_is_float(vstr);
if (is_float)
{
d->params[id].v.create((int)af.size());
memcpy(d->params[id].v.data, af.data(), af.size() * 4);
}
else
{
d->params[id].v.create((int)ai.size());
memcpy(d->params[id].v.data, ai.data(), ai.size() * 4);
}

d->params[id].type = is_float ? 6 : 5;
}
else
{
if (is_float)
{
d->params[id].f = vstr_to_float(vstr);
Expand Down Expand Up @@ -375,7 +506,12 @@ int ParamDict::load_param_bin(const DataReader& dr)
while (id != -233)
{
bool is_array = id <= -23300;
if (is_array)
bool is_string = id <= -23400;
if (is_string)
{
id = -id - 23400;
}
else if (is_array)
{
id = -id - 23300;
}
Expand All @@ -386,7 +522,44 @@ int ParamDict::load_param_bin(const DataReader& dr)
return -1;
}

if (is_array)
if (is_string)
{
int len = 0;
nread = dr.read(&len, sizeof(int));
if (nread != sizeof(int))
{
NCNN_LOGE("ParamDict read array length failed %zd", nread);
return -1;
}

#if __BIG_ENDIAN__
swap_endianness_32(&len);
#endif

if (len > 255)
{
NCNN_LOGE("string too long (id=%d)", id);
return -1;
}

size_t len_padded = (len + 3) / 4 * 4;
std::vector<char> tmpstr(len_padded + 1);

char* ptr = (char*)tmpstr.data();
nread = dr.read(ptr, len_padded);
if (nread != len_padded)
{
NCNN_LOGE("ParamDict read string failed %zd", nread);
return -1;
}

tmpstr[len_padded] = '\0';

d->params[id].s = tmpstr.data();

d->params[id].type = 7;
}
else if (is_array)
{
int len = 0;
nread = dr.read(&len, sizeof(int));
Expand Down
4 changes: 4 additions & 0 deletions src/paramdict.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,17 @@ class NCNN_EXPORT ParamDict
float get(int id, float def) const;
// get array
Mat get(int id, const Mat& def) const;
// get string
std::string get(int id, const std::string& def) const;

// set int
void set(int id, int i);
// set float
void set(int id, float f);
// set array
void set(int id, const Mat& v);
// set string
void set(int id, const std::string& s);

protected:
friend class Net;
Expand Down
10 changes: 3 additions & 7 deletions src/simplestl.h
Original file line number Diff line number Diff line change
Expand Up @@ -536,15 +536,11 @@ struct NCNN_EXPORT string : public vector<char>
}
bool operator==(const string& str2) const
{
return strcmp(data_, str2.data_) == 0;
return size_ == str2.size_ && strncmp(data_, str2.data_, size_) == 0;
}
bool operator==(const char* str2) const
bool operator!=(const string& str2) const
{
return strcmp(data_, str2) == 0;
}
bool operator!=(const char* str2) const
{
return strcmp(data_, str2) != 0;
return size_ != str2.size_ || strncmp(data_, str2.data_, size_) != 0;
}
string& operator+=(const string& str1)
{
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ endif()

ncnn_add_test(c_api)
ncnn_add_test(cpu)
ncnn_add_test(paramdict)

if(NCNN_VULKAN)
ncnn_add_test(command)
Expand Down
Loading
Loading