diff --git a/src/checker.cpp b/src/checker.cpp index 4ff7d05..1fdad3a 100644 --- a/src/checker.cpp +++ b/src/checker.cpp @@ -52,15 +52,16 @@ int Checker::masking(const u8 *in, int ilen, u8 *out, int olen, const char *mask out[oofs] = 0; return oofs; } -const char *Checker::filter(const char *in, int ilen, char *out, int *olen, const char *mask) { +const char *Checker::filter(const char *in, int ilen, char *out, int *olen, const char *mask, ContextChecker checker) { if (mask == nullptr) { mask = "?"; } int iofs = 0; int msz = strnlen(mask, MAX_FILTER_STRING); int oofs = 0, tmp; + void *ctx; const u8 *iptr = reinterpret_cast(in); u8 *optr = reinterpret_cast(out); while (iofs < ilen) { - if (trie_.get(iptr + iofs, ilen - iofs, &tmp) != nullptr) { + if ((ctx = trie_.get(iptr + iofs, ilen - iofs, &tmp)) != nullptr && checker(in, ilen, iofs, tmp, ctx)) { oofs += masking(iptr + iofs, tmp, optr + oofs, *olen - oofs, mask, msz); iofs += tmp; } else { @@ -76,12 +77,14 @@ const char *Checker::filter(const char *in, int ilen, char *out, int *olen, cons out[*olen] = 0; return out; } -bool Checker::should_filter(const char *in, int ilen, int *start, int *count, void **pctx) { +bool Checker::should_filter(const char *in, int ilen, int *start, int *count, void **pctx, ContextChecker checker) { int iofs = 0, tmp; + void *ctx; const u8 *iptr = reinterpret_cast(in); while (iofs < ilen) { - if ((*pctx = trie_.get(iptr + iofs, ilen - iofs, count)) != nullptr) { + if ((ctx = trie_.get(iptr + iofs, ilen - iofs, count)) != nullptr && checker(in, ilen, iofs, *count, ctx)) { *start = iofs; + *pctx = ctx; return true; } else { tmp = utf8::peek(iptr + iofs, ilen - iofs); diff --git a/src/checker.h b/src/checker.h index a3a8c9f..4e1e196 100644 --- a/src/checker.h +++ b/src/checker.h @@ -10,6 +10,7 @@ class WordChecker; class Checker { public: typedef shutup_allocator Allocator; + typedef std::function ContextChecker; static const int MAX_FILTER_STRING = 1024 * 1024;//1M class Mempool : public IMempool { Allocator alloc_; @@ -42,8 +43,9 @@ class Checker { void ignore_glyphs(const char *glyphs); inline void add_word(const char *s, void *ctx) { trie_.add(s, ctx); } inline void remove(const char *s) { trie_.remove(s); } - const char *filter(const char *in, int ilen, char *out, int *olen, const char *mask = nullptr); - bool should_filter(const char *in, int ilen, int *start, int *count, void **pctx = nullptr); + static bool truer(const char *in, int ilen, int start, int count, void *ctx) { return true; } + const char *filter(const char *in, int ilen, char *out, int *olen, const char *mask = nullptr, ContextChecker checker = truer); + bool should_filter(const char *in, int ilen, int *start, int *count, void **pctx = nullptr, ContextChecker checker = truer); public: static language::WordChecker *by(const char *lang, Mempool &m); template static C *new_word_checker(Mempool &m) { diff --git a/src/shutup.cpp b/src/shutup.cpp index cf51a9a..cac65c9 100644 --- a/src/shutup.cpp +++ b/src/shutup.cpp @@ -28,14 +28,15 @@ void shutup_add_word(shutter s, const char *word, void *ctx) { c->add(word, ctx); } } -void *shutup_should_filter(shutter s, const char *in, int ilen, int *start, int *count) { +void *shutup_should_filter(shutter s, const char *in, int ilen, int *start, int *count, + shutup_context_checker checker) { void *p; shutup::Checker *c = reinterpret_cast(s); if (!c->valid()) { *start = -1; return nullptr; } - if (c->should_filter(in, ilen, start, count, &p)) { + if (c->should_filter(in, ilen, start, count, &p, checker)) { //shutup_log("should filter: length: %d [%s]\n", *olen, out); return p; } @@ -54,7 +55,8 @@ static char *allocate(int *size) { *size = s_buff_size; return s_buff; } -const char *shutup_filter(shutter s, const char *in, int ilen, char *out, int *olen, const char *mask) { +const char *shutup_filter(shutter s, const char *in, int ilen, char *out, int *olen, const char *mask, + shutup_context_checker checker) { int tmp; shutup::Checker *c = reinterpret_cast(s); if (!c->valid()) { @@ -69,7 +71,7 @@ const char *shutup_filter(shutter s, const char *in, int ilen, char *out, int *o return nullptr; } } - if (c->filter(in, ilen, out, olen, mask) != nullptr) { + if (c->filter(in, ilen, out, olen, mask, checker) != nullptr) { out[*olen] = 0; return reinterpret_cast(out); } diff --git a/src/shutup.h b/src/shutup.h index 0d40f6b..b1d74e3 100644 --- a/src/shutup.h +++ b/src/shutup.h @@ -9,13 +9,16 @@ extern "C" { void (*free)(void *); void *(*realloc)(void *, size_t); } shutup_allocator; + typedef bool (*shutup_context_checker)(const char *in, int ilen, int start, int count, void *ctx); extern shutter shutup_new(const char *lang, shutup_allocator *a); extern void shutup_delete(shutter s); extern void shutup_set_alias(shutter s, const char *target, const char *alias); extern void shutup_ignore_glyphs(shutter s, const char *glyphs); extern void shutup_add_word(shutter s, const char *word, void *ctx); - extern void *shutup_should_filter(shutter s, const char *in, int ilen, int *start, int *count); - extern const char *shutup_filter(shutter s, const char *in, int ilen, char *out, int *olen, const char *mask); + extern void *shutup_should_filter(shutter s, const char *in, int ilen, int *start, int *count, + shutup_context_checker checker); + extern const char *shutup_filter(shutter s, const char *in, int ilen, char *out, int *olen, const char *mask, + shutup_context_checker checker); //for debugging typedef void (*shutup_logger)(const char *); extern void shutup_set_logger(shutup_logger logger); diff --git a/test/checker.cpp b/test/checker.cpp index 15712b2..008c737 100644 --- a/test/checker.cpp +++ b/test/checker.cpp @@ -14,6 +14,7 @@ struct testcase { const char *matched_; const char *expect_; void *ctx_; + bool (*checker_)(const char *in, int ilen, int start, int count, void *ctx); }; struct filter { const char *text_; @@ -36,9 +37,10 @@ struct testcase { } for (auto &i : inputs_) { int ilen = std::strlen(i.text_); + auto checker = (i.checker_ == nullptr ? shutup::Checker::truer : i.checker_); int start, count; void *ctx; - if (i.filtered_ != c.should_filter(i.text_, std::strlen(i.text_), &start, &count, &ctx)) { + if (i.filtered_ != c.should_filter(i.text_, std::strlen(i.text_), &start, &count, &ctx, checker)) { TRACE("input:[%s]\n", i.text_); return "text should be filtered but actually not"; } @@ -47,15 +49,15 @@ struct testcase { char buff[count + 1]; std::memcpy(buff, i.text_ + start, count); buff[count] = 0; if (std::strcmp(i.matched_, buff) != 0 || ctx != i.ctx_) { - TRACE("filtered:[%s] [%s]\n", i.matched_, buff); + TRACE("filtered:[%s] [%s] %p %p\n", i.matched_, buff, ctx, i.ctx_); return "filtered but match part does not match expected"; } } int olen = ilen * shutup::utf8::MAX_BYTE_PER_GRYPH; char buff[olen]; const char *r = mask_ == nullptr ? - c.filter(i.text_, std::strlen(i.text_), buff, &olen) : - c.filter(i.text_, std::strlen(i.text_), buff, &olen, mask_); + c.filter(i.text_, std::strlen(i.text_), buff, &olen, nullptr, checker) : + c.filter(i.text_, std::strlen(i.text_), buff, &olen, mask_, checker); if (std::strcmp(i.expect_, r) != 0) { TRACE("filter:[%s] => [%s]\n", i.expect_, r); return "filter result does not match expected"; @@ -70,6 +72,30 @@ struct testcase { static void *p(int id) { return reinterpret_cast(id); } +static bool p3_is_forward_match(const char *in, int ilen, int start, int count, void *ctx) { + if (ctx == p(3)) { + if (start == 0) { + return true; + } + } + return false; +} +static bool p3_is_backward_match(const char *in, int ilen, int start, int count, void *ctx) { + if (ctx == p(3)) { + if ((start + count) == ilen) { + return true; + } + } + return false; +} +static bool p3_is_exact_match(const char *in, int ilen, int start, int count, void *ctx) { + if (ctx == p(3)) { + if (start == 0 && count == ilen) { + return true; + } + } + return false; +} extern const char *checker_test() { std::vector cases{ @@ -102,6 +128,26 @@ extern const char *checker_test() { p(1), }, {"OK:「ンビエト」、NG:「ワソド」", true, ":「ワソド", "OK:「ンビエト」、NG:「???」", p(5)}, + + {"これは馬津怒輪亜怒ですか", false, "", "これは馬津怒輪亜怒ですか", nullptr, p3_is_exact_match}, + {"これは馬津怒輪亜怒ですか", false, "", "これは馬津怒輪亜怒ですか", nullptr, p3_is_backward_match}, + {"これは馬津怒輪亜怒ですか", false, "", "これは馬津怒輪亜怒ですか", nullptr, p3_is_forward_match}, + {"これは馬津怒輪亜怒ですか", true, "馬津怒輪亜怒", "これは??????ですか", p(3)}, + + {"馬津怒輪亜怒ですか", false, "", "馬津怒輪亜怒ですか", nullptr, p3_is_exact_match}, + {"馬津怒輪亜怒ですか", false, "", "馬津怒輪亜怒ですか", nullptr, p3_is_backward_match}, + {"馬津怒輪亜怒ですか", true, "馬津怒輪亜怒", "??????ですか", p(3), p3_is_forward_match}, + {"馬津怒輪亜怒ですか", true, "馬津怒輪亜怒", "??????ですか", p(3)}, + + {"これは馬津怒輪亜怒", false, "", "これは馬津怒輪亜怒", nullptr, p3_is_exact_match}, + {"これは馬津怒輪亜怒", true, "馬津怒輪亜怒", "これは??????", p(3), p3_is_backward_match}, + {"これは馬津怒輪亜怒", false, "", "これは馬津怒輪亜怒", nullptr, p3_is_forward_match}, + {"これは馬津怒輪亜怒", true, "馬津怒輪亜怒", "これは??????", p(3)}, + + {"馬津怒輪亜怒", true, "馬津怒輪亜怒", "??????", p(3), p3_is_exact_match}, + {"馬津怒輪亜怒", true, "馬津怒輪亜怒", "??????", p(3), p3_is_backward_match}, + {"馬津怒輪亜怒", true, "馬津怒輪亜怒", "??????", p(3), p3_is_forward_match}, + {"馬津怒輪亜怒", true, "馬津怒輪亜怒", "??????", p(3)}, }, }, };