diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c97235d97a0..19bd020c7db 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -22,6 +22,7 @@ set(ncnn_SRCS command.cpp cpu.cpp datareader.cpp + expression.cpp gpu.cpp layer.cpp mat.cpp diff --git a/src/expression.cpp b/src/expression.cpp new file mode 100644 index 00000000000..fe299edd4a7 --- /dev/null +++ b/src/expression.cpp @@ -0,0 +1,544 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+
+#include "expression.h"
+
+namespace ncnn {
+
+// return true when t references a blob dimension, eg. 0w 1h 2d 3c
+// the first char is the blob index (single digit), the second one selects w/h/d/c
+static bool is_blob_dim_token(const std::string& t)
+{
+    return t.size() == 2 && (t[0] >= '0' && t[0] <= '9') && (t[1] == 'w' || t[1] == 'h' || t[1] == 'd' || t[1] == 'c');
+}
+
+// split expr by , ( ) and return the non-empty tokens in their original order
+// "/(0w,2),0c" -> "/" "0w" "2" "0c"
+static std::vector<std::string> split_expression_tokens(const std::string& expr)
+{
+    std::vector<std::string> tokens;
+
+    std::string t;
+    for (size_t i = 0; i < expr.size(); i++)
+    {
+        char ch = expr[i];
+
+        if (ch == '(' || ch == ')' || ch == ',')
+        {
+            if (!t.empty())
+            {
+                tokens.push_back(t);
+                t.clear();
+            }
+        }
+        else
+        {
+#if NCNN_SIMPLESTL
+            t.resize(t.size() + 1);
+            t[t.size() - 1] = ch;
+#else
+            t += ch;
+#endif
+        }
+    }
+
+    if (!t.empty())
+    {
+        tokens.push_back(t);
+    }
+
+    return tokens;
+}
+
+// count how many blobs are referenced inside expression
+// the result is the highest referenced blob index + 1, or 0 when no blob is referenced
+int count_expression_blobs(const std::string& expr)
+{
+    int count = 0;
+
+    const std::vector<std::string> tokens = split_expression_tokens(expr);
+    for (size_t i = 0; i < tokens.size(); i++)
+    {
+        const std::string& t = tokens[i];
+
+        if (is_blob_dim_token(t))
+        {
+            int blob_index = t[0] - '0';
+            count = std::max(count, blob_index + 1);
+        }
+    }
+
+    return count;
+}
+
+// evaluate a comma separated list of prefix-notation expressions
+// expr     eg. "/(0w,2),*(0h,2),0c"
+// blobs    the Mats referenced by the 0w/1h/... tokens
+// return   one int per list element, float results are truncated
+//          an empty list is returned when the expression is empty or malformed (operator without operands)
+std::vector<int> eval_list_expression(const std::string& expr, const std::vector<Mat>& blobs)
+{
+    // /(0w,2),*(0h,2),0c
+
+    // split by , ( )
+    //
+    // /
+    // 0w
+    // 2
+    // -------------------
+    // *
+    // 0h
+    // 2
+    // -------------------
+    // 0c
+    // -------------------
+
+    // split into tokens
+    const std::vector<std::string> tokens = split_expression_tokens(expr);
+
+    // / 0w 2 * 0h 2 0c
+
+    struct typed_value
+    {
+        int type; // 0=i 1=f
+        union
+        {
+            int i;
+            float f;
+        };
+
+        typed_value()
+            : type(0), i(0)
+        {
+        }
+        typed_value(int _i)
+            : type(0), i(_i)
+        {
+        }
+        typed_value(float _f)
+            : type(1), f(_f)
+        {
+        }
+
+        int to_int() const
+        {
+            if (type == 0)
+                return i;
+
+            // trunc by default
+            return (int)f;
+        }
+    };
+
+    // scan and stack
+    // tokens are visited from right to left, so the operands of a prefix operator
+    // are already on the stack when the operator is reached, first operand on top
+    // operands are popped with operator[] + resize, which works for both std and simplestl vector
+    std::vector<typed_value> exprstack;
+    for (int i = (int)tokens.size() - 1; i >= 0; i--)
+    {
+        const std::string& t = tokens[i];
+
+        // + - * / 0w 0h 0d 0c 12345
+
+        if (is_blob_dim_token(t))
+        {
+            size_t blob_index = t[0] - '0';
+            if (blob_index >= blobs.size())
+            {
+                NCNN_LOGE("shape expression blob index %d out of bound!", (int)blob_index);
+                if (blobs.empty())
+                {
+                    // no blob to fall back to
+                    return std::vector<int>();
+                }
+                blob_index = 0;
+            }
+
+            const Mat& blob = blobs[blob_index];
+            int size;
+            if (t[1] == 'w')
+                size = blob.w;
+            else if (t[1] == 'h')
+                size = blob.h;
+            else if (t[1] == 'd')
+                size = blob.d;
+            else // if (t[1] == 'c')
+                size = blob.c;
+
+            exprstack.push_back(size);
+        }
+        else if (t == "+" || t == "-" || t == "*" || t == "/" || t == "max" || t == "min")
+        {
+            if (exprstack.size() < 2)
+            {
+                NCNN_LOGE("shape expression operator %s lacks operand!", t.c_str());
+                return std::vector<int>();
+            }
+
+            typed_value ta = exprstack[exprstack.size() - 1];
+            exprstack.resize(exprstack.size() - 1);
+            typed_value tb = exprstack[exprstack.size() - 1];
+            exprstack.resize(exprstack.size() - 1);
+
+            if (ta.type == 0 && tb.type == 0)
+            {
+                // integer arithmetic as long as both operands are integer
+                int a = ta.i;
+                int b = tb.i;
+
+                if (t == "+")
+                {
+                    exprstack.push_back(a + b);
+                }
+                else if (t == "-")
+                {
+                    exprstack.push_back(a - b);
+                }
+                else if (t == "*")
+                {
+                    exprstack.push_back(a * b);
+                }
+                else if (t == "/")
+                {
+                    if (b == 0)
+                    {
+                        NCNN_LOGE("expr divide by zero");
+                        exprstack.push_back(a);
+                    }
+                    else
+                    {
+                        exprstack.push_back(a / b);
+                    }
+                }
+                else if (t == "max")
+                {
+                    exprstack.push_back(std::max(a, b));
+                }
+                else // if (t == "min")
+                {
+                    exprstack.push_back(std::min(a, b));
+                }
+            }
+            else
+            {
+                // promote to float when any operand is float
+                float a = ta.type == 0 ? ta.i : ta.f;
+                float b = tb.type == 0 ? tb.i : tb.f;
+
+                if (t == "+")
+                {
+                    exprstack.push_back(a + b);
+                }
+                else if (t == "-")
+                {
+                    exprstack.push_back(a - b);
+                }
+                else if (t == "*")
+                {
+                    exprstack.push_back(a * b);
+                }
+                else if (t == "/")
+                {
+                    exprstack.push_back(a / b);
+                }
+                else if (t == "max")
+                {
+                    exprstack.push_back(std::max(a, b));
+                }
+                else // if (t == "min")
+                {
+                    exprstack.push_back(std::min(a, b));
+                }
+            }
+        }
+        else if (t == "abs" || t == "neg" || t == "sign" || t == "square")
+        {
+            if (exprstack.empty())
+            {
+                NCNN_LOGE("shape expression operator %s lacks operand!", t.c_str());
+                return std::vector<int>();
+            }
+
+            typed_value ta = exprstack[exprstack.size() - 1];
+            exprstack.resize(exprstack.size() - 1);
+
+            if (ta.type == 0)
+            {
+                int a = ta.i;
+
+                if (t == "abs")
+                {
+                    exprstack.push_back(a > 0 ? a : -a);
+                }
+                else if (t == "neg")
+                {
+                    exprstack.push_back(-a);
+                }
+                else if (t == "sign")
+                {
+                    exprstack.push_back(a > 0 ? 1 : (a == 0 ? 0 : -1));
+                }
+                else // if (t == "square")
+                {
+                    exprstack.push_back(a * a);
+                }
+            }
+            else
+            {
+                float a = ta.f;
+
+                if (t == "abs")
+                {
+                    exprstack.push_back(fabsf(a));
+                }
+                else if (t == "neg")
+                {
+                    exprstack.push_back(-a);
+                }
+                else if (t == "sign")
+                {
+                    // the sign of a float is pushed as integer
+                    exprstack.push_back(a > 0.f ? 1 : (a == 0.f ? 0 : -1));
+                }
+                else // if (t == "square")
+                {
+                    exprstack.push_back(a * a);
+                }
+            }
+        }
+        else if (t == "trunc" || t == "ceil" || t == "floor" || t == "round")
+        {
+            if (exprstack.empty())
+            {
+                NCNN_LOGE("shape expression operator %s lacks operand!", t.c_str());
+                return std::vector<int>();
+            }
+
+            typed_value ta = exprstack[exprstack.size() - 1];
+            exprstack.resize(exprstack.size() - 1);
+
+            if (ta.type == 0)
+            {
+                // rounding an integer is a no-op
+                int a = ta.i;
+                exprstack.push_back(a);
+            }
+            else
+            {
+                // rounding always yields an integer value
+                float a = ta.f;
+
+                if (t == "trunc")
+                {
+                    exprstack.push_back((int)a);
+                }
+                else if (t == "ceil")
+                {
+                    exprstack.push_back((int)ceil(a));
+                }
+                else if (t == "floor")
+                {
+                    exprstack.push_back((int)floor(a));
+                }
+                else // if (t == "round")
+                {
+                    exprstack.push_back((int)round(a));
+                }
+            }
+        }
+        else if (t == "acos"
+                 || t == "acosh"
+                 || t == "asin"
+                 || t == "asinh"
+                 || t == "atan"
+                 || t == "atanh"
+                 || t == "cos"
+                 || t == "cosh"
+                 || t == "erf"
+                 || t == "exp"
+                 || t == "log"
+                 || t == "log10"
+                 || t == "reciprocal"
+                 || t == "rsqrt"
+                 || t == "sin"
+                 || t == "sinh"
+                 || t == "sqrt"
+                 || t == "tan"
+                 || t == "tanh")
+        {
+            if (exprstack.empty())
+            {
+                NCNN_LOGE("shape expression operator %s lacks operand!", t.c_str());
+                return std::vector<int>();
+            }
+
+            typed_value ta = exprstack[exprstack.size() - 1];
+            exprstack.resize(exprstack.size() - 1);
+
+            // float math always produces a float value
+            float a = ta.type == 0 ? ta.i : ta.f;
+
+            if (t == "acos")
+            {
+                exprstack.push_back(acosf(a));
+            }
+            else if (t == "acosh")
+            {
+                exprstack.push_back(acoshf(a));
+            }
+            else if (t == "asin")
+            {
+                exprstack.push_back(asinf(a));
+            }
+            else if (t == "asinh")
+            {
+                exprstack.push_back(asinhf(a));
+            }
+            else if (t == "atan")
+            {
+                exprstack.push_back(atanf(a));
+            }
+            else if (t == "atanh")
+            {
+                exprstack.push_back(atanhf(a));
+            }
+            else if (t == "cos")
+            {
+                exprstack.push_back(cosf(a));
+            }
+            else if (t == "cosh")
+            {
+                exprstack.push_back(coshf(a));
+            }
+            else if (t == "erf")
+            {
+                exprstack.push_back(erff(a));
+            }
+            else if (t == "exp")
+            {
+                exprstack.push_back(expf(a));
+            }
+            else if (t == "log")
+            {
+                exprstack.push_back(logf(a));
+            }
+            else if (t == "log10")
+            {
+                exprstack.push_back(log10f(a));
+            }
+            else if (t == "reciprocal")
+            {
+                exprstack.push_back(1.f / a);
+            }
+            else if (t == "rsqrt")
+            {
+                exprstack.push_back(1.f / sqrtf(a));
+            }
+            else if (t == "sin")
+            {
+                exprstack.push_back(sinf(a));
+            }
+            else if (t == "sinh")
+            {
+                exprstack.push_back(sinhf(a));
+            }
+            else if (t == "sqrt")
+            {
+                exprstack.push_back(sqrtf(a));
+            }
+            else if (t == "tan")
+            {
+                exprstack.push_back(tanf(a));
+            }
+            else // if (t == "tanh")
+            {
+                exprstack.push_back(tanhf(a));
+            }
+        }
+        else if (t == "atan2"
+                 || t == "fmod"
+                 || t == "pow"
+                 || t == "remainder"
+                 || t == "logaddexp")
+        {
+            if (exprstack.size() < 2)
+            {
+                NCNN_LOGE("shape expression operator %s lacks operand!", t.c_str());
+                return std::vector<int>();
+            }
+
+            typed_value ta = exprstack[exprstack.size() - 1];
+            exprstack.resize(exprstack.size() - 1);
+            typed_value tb = exprstack[exprstack.size() - 1];
+            exprstack.resize(exprstack.size() - 1);
+
+            float a = ta.type == 0 ? ta.i : ta.f;
+            float b = tb.type == 0 ? tb.i : tb.f;
+
+            if (t == "atan2")
+            {
+                exprstack.push_back(atan2f(a, b));
+            }
+            else if (t == "fmod")
+            {
+                exprstack.push_back(fmodf(a, b));
+            }
+            else if (t == "pow")
+            {
+                exprstack.push_back(powf(a, b));
+            }
+            else if (t == "remainder")
+            {
+                // fmodf keeps the sign of a, shift into the sign of b
+                // an exact multiple must stay zero instead of becoming b
+                float r = fmodf(a, b);
+                if (r != 0.f && ((r < 0.f) != (b < 0.f)))
+                    r += b;
+                exprstack.push_back(r);
+            }
+            else // if (t == "logaddexp")
+            {
+                exprstack.push_back(logf(expf(a) + expf(b)));
+            }
+        }
+        else
+        {
+            // literal
+            // a token is integer only when its integer and float readings agree, eg. 12 or 2.0
+            int vi;
+            float vf;
+            int nscani = sscanf(t.c_str(), "%d", &vi);
+            int nscanf = sscanf(t.c_str(), "%f", &vf);
+            if (nscani == 1 && nscanf == 1 && vi == vf)
+            {
+                exprstack.push_back(vi);
+            }
+            else if (nscanf == 1)
+            {
+                exprstack.push_back(vf);
+            }
+            else
+            {
+                NCNN_LOGE("malformed literal token %s", t.c_str());
+                exprstack.push_back(0);
+            }
+        }
+    }
+
+    // the first list element is on top of the stack, pop until empty to keep list order
+    std::vector<int> list;
+    while (!exprstack.empty())
+    {
+        int size = exprstack[exprstack.size() - 1].to_int();
+        exprstack.resize(exprstack.size() - 1);
+        list.push_back(size);
+    }
+
+    return list;
+}
+
+} // namespace ncnn
diff --git a/src/expression.h b/src/expression.h
new file mode 100644
index 00000000000..1df09aaf8da
--- /dev/null
+++ b/src/expression.h
@@ -0,0 +1,31 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License.
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "mat.h" + +namespace ncnn { + +// count how many blobs are referenced inside expression +int count_expression_blobs(const std::string& expr); + +// resolve reshape shape from expression and input blobs +// resolve slice indices(starts, ends) from expression and input blobs +// supported binary operator: + - * / max min +// supported unary operator: neg trunc ceil floor round +// expr = "/(0w,2),*(1h,2),0c" +// blobs = (A, B) +// outlist = (A.w/2, B.h*2, A.c) +std::vector eval_list_expression(const std::string& expr, const std::vector& blobs); + +} // namespace ncnn diff --git a/src/layer/reshape.cpp b/src/layer/reshape.cpp index b35f5971056..3f07186f29d 100644 --- a/src/layer/reshape.cpp +++ b/src/layer/reshape.cpp @@ -14,6 +14,8 @@ #include "reshape.h" +#include "expression.h" + namespace ncnn { Reshape::Reshape() @@ -30,6 +32,11 @@ int Reshape::load_param(const ParamDict& pd) c = pd.get(2, -233); permute = pd.get(3, 0); + if (permute == 1) + { + NCNN_LOGE("reshape permute is deprecated, and will be removed"); + } + ndim = 4; if (d == -233) ndim = 3; @@ -40,23 +47,24 @@ int Reshape::load_param(const ParamDict& pd) if (w == -233) ndim = 0; + shape_expr = pd.get(6, ""); + + // count reference blobs + if (!shape_expr.empty() && count_expression_blobs(shape_expr) > 1) + { + one_blob_only = false; + } + return 0; } -int Reshape::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +static int reshape(const Mat& bottom_blob, Mat& top_blob, int ndim, int outw, int outh, int outd, int outc, int permute, const Option& opt) { size_t 
elemsize = bottom_blob.elemsize; int total = bottom_blob.w * bottom_blob.h * bottom_blob.d * bottom_blob.c; int dims = bottom_blob.dims; - // resolve out shape - - int outw = w; - int outh = h; - int outd = d; - int outc = c; - if (ndim == 1) { if (outw == 0) @@ -351,4 +359,76 @@ int Reshape::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) c return 0; } +int Reshape::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + int outw = w; + int outh = h; + int outd = d; + int outc = c; + + // resolve out shape + if (!shape_expr.empty()) + { + eval_shape_expr(bottom_blob, outw, outh, outd, outc); + } + + return reshape(bottom_blob, top_blob, ndim, outw, outh, outd, outc, permute, opt); +} + +int Reshape::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + int outw = w; + int outh = h; + int outd = d; + int outc = c; + + // resolve out shape + if (!shape_expr.empty()) + { + eval_shape_expr(bottom_blobs, outw, outh, outd, outc); + } + + return reshape(bottom_blobs[0], top_blobs[0], ndim, outw, outh, outd, outc, permute, opt); +} + +void Reshape::eval_shape_expr(const Mat& bottom_blob, int& outw, int& outh, int& outd, int& outc) const +{ + std::vector bottom_blobs(1); + bottom_blobs[0] = bottom_blob; + eval_shape_expr(bottom_blobs, outw, outh, outd, outc); +} + +void Reshape::eval_shape_expr(const std::vector& bottom_blobs, int& outw, int& outh, int& outd, int& outc) const +{ + // [size(@0,0),size(@0,1),12,64] + std::vector shape = eval_list_expression(shape_expr, bottom_blobs); + + outw = 1; + outh = 1; + outd = 1; + outc = 1; + if (shape.size() == 1) + { + outw = shape[0]; + } + if (shape.size() == 2) + { + outw = shape[0]; + outh = shape[1]; + } + if (shape.size() == 3) + { + outw = shape[0]; + outh = shape[1]; + outc = shape[2]; + } + if (shape.size() == 4) + { + outw = shape[0]; + outh = shape[1]; + outd = shape[2]; + outc = shape[3]; + } +} + } // namespace ncnn diff --git 
a/src/layer/reshape.h b/src/layer/reshape.h index b1072a48c38..9eed5743d3b 100644 --- a/src/layer/reshape.h +++ b/src/layer/reshape.h @@ -28,6 +28,12 @@ class Reshape : public Layer virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + +protected: + void eval_shape_expr(const Mat& bottom_blob, int& outw, int& outh, int& outd, int& outc) const; + void eval_shape_expr(const std::vector& bottom_blobs, int& outw, int& outh, int& outd, int& outc) const; + public: // reshape flag // 0 = copy from bottom @@ -42,6 +48,9 @@ class Reshape : public Layer int permute; int ndim; + + // [size(@0,0),size(@0,1),12,64] + std::string shape_expr; }; } // namespace ncnn diff --git a/src/paramdict.cpp b/src/paramdict.cpp index 4704ea169e6..be7d4ef71da 100644 --- a/src/paramdict.cpp +++ b/src/paramdict.cpp @@ -188,7 +188,7 @@ static bool vstr_is_float(const char vstr[16]) static bool vstr_is_string(const char vstr[16]) { - return isalpha(vstr[0]); + return isalpha(vstr[0]) || vstr[0] == '\"'; } static float vstr_to_float(const char vstr[16]) @@ -356,7 +356,14 @@ int ParamDict::load_param(const DataReader& dr) // scan the remaining string char vstr2[256]; vstr2[241] = '\0'; // max 255 = 15 + 240 - nscan = dr.scan("%255[^\n ]", vstr2); + if (vstr[0] == '\"') + { + nscan = dr.scan("%255[^\"]\"", vstr2); + } + else + { + nscan = dr.scan("%255[^\n ]", vstr2); + } if (nscan == 1) { if (vstr2[241] != '\0') @@ -365,13 +372,22 @@ int ParamDict::load_param(const DataReader& dr) return -1; } - d->params[id].s = std::string(vstr) + vstr2; + if (vstr[0] == '\"') + d->params[id].s = std::string(&vstr[1]) + vstr2; + else + d->params[id].s = std::string(vstr) + vstr2; } else { - d->params[id].s = std::string(vstr); + if (vstr[0] == '\"') + d->params[id].s = std::string(&vstr[1]); + else + d->params[id].s = std::string(vstr); } + if 
(d->params[id].s[d->params[id].s.size() - 1] == '\"') + d->params[id].s.resize(d->params[id].s.size() - 1); + d->params[id].type = 7; continue; diff --git a/tests/test_paramdict.cpp b/tests/test_paramdict.cpp index 503de188e17..7d101d650d8 100644 --- a/tests/test_paramdict.cpp +++ b/tests/test_paramdict.cpp @@ -210,7 +210,7 @@ static int test_paramdict_1() static int test_paramdict_2() { ParamDictTest pdt; - pdt.load_param("0=bij,bjk->bik 1=This_is_a_very_long_long_string 2=X"); + pdt.load_param("0=bij,bjk->bik 1=This_is_a_very_long_long_string 3=\"1,2,3 and 6.667 zzz\" 2=\"X\""); // string int types = pdt.type(0); @@ -254,6 +254,20 @@ static int test_paramdict_2() return -1; } + // string + types = pdt.type(3); + if (types != 7) + { + fprintf(stderr, "test_paramdict string type failed %d != 7\n", types); + return -1; + } + s = pdt.get(3, ""); + if (s != "1,2,3 and 6.667 zzz") + { + fprintf(stderr, "test_paramdict string text failed %s != \"1,2,3 and 6.667 zzz\"\n", s.c_str()); + return -1; + } + return 0; } @@ -661,8 +675,24 @@ static int test_paramdict_6() return 0; } +#include "expression.h" + int main() { + std::vector blobs(2); + blobs[0].w = 100; + blobs[0].h = 200; + blobs[0].c = 44; + + blobs[1].w = 10; + blobs[1].h = 20; + blobs[1].c = 4; + + // outshape = ( int(a.w * 0.5) + (a.c - 10), floor(b.h / 0.5), a.c + b.c, round(2.0) ) + std::vector outshape = eval_list_expression("+(trunc(*(0w,0.5)),-(0c,10)),floor(/(1h,0.5)),+(0c,1c),round(2.0)", blobs); + + fprintf(stderr, "%d %d %d %d %d\n", (int)outshape.size(), outshape[0], outshape[1], outshape[2], outshape[3]); + return 0 || test_paramdict_0() || test_paramdict_1()