Skip to content

Commit

Permalink
Implement MSC3664, pushrules for related events
Browse files Browse the repository at this point in the history
  • Loading branch information
deepbluev7 committed Dec 3, 2022
1 parent c7a13e7 commit 7155cbb
Show file tree
Hide file tree
Showing 4 changed files with 608 additions and 57 deletions.
13 changes: 12 additions & 1 deletion include/mtx/pushrules.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <variant>
#include <vector>

#include "mtx/events/common.hpp"
#include "mtx/events/power_levels.hpp"

namespace mtx {
Expand Down Expand Up @@ -47,6 +48,12 @@ struct PushCondition
//! defaults to ==.
std::string is;

//! The relation type to match on. Only valid for `im.nheko.msc3664.related_event_match`
//! conditions.
mtx::common::RelationType rel_type = mtx::common::RelationType::Unsupported;
//! Wether to match fallback relations or not.
bool include_fallback = false;

friend void to_json(nlohmann::json &obj, const PushCondition &condition);
friend void from_json(const nlohmann::json &obj, PushCondition &condition);
};
Expand Down Expand Up @@ -200,10 +207,14 @@ class PushRuleEvaluator
//! Evaluate the pushrules for @event .
///
/// You need to have the room_id set for the event.
/// `relatedEvents` is a mapping of rel_type to event. Pass all the events that are related to
/// by this event here.
/// \returns the actions to apply.
[[nodiscard]] std::vector<actions::Action> evaluate(
const mtx::events::collections::TimelineEvent &event,
const RoomContext &ctx) const;
const RoomContext &ctx,
const std::vector<std::pair<mtx::common::Relation, mtx::events::collections::TimelineEvent>>
&relatedEvents) const;

private:
struct OptimizedRules;
Expand Down
3 changes: 2 additions & 1 deletion lib/structs/events/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ from_json(const json &obj, RelationType &type)
type = RelationType::Reference;
else if (obj.get<std::string>() == "m.replace")
type = RelationType::Replace;
else if (obj.get<std::string>() == "im.nheko.relations.v1.in_reply_to")
else if (obj.get<std::string>() == "im.nheko.relations.v1.in_reply_to" ||
obj.get<std::string>() == "m.in_reply_to")
type = RelationType::InReplyTo;
else if (obj.get<std::string>() == "m.thread")
type = RelationType::Thread;
Expand Down
138 changes: 114 additions & 24 deletions lib/structs/pushrules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@
#include "mtx/events/collections.hpp"
#include "mtx/log.hpp"

namespace {
struct RelatedEvents
{
std::vector<std::unordered_map<std::string, std::string>>
fallbacks; //!< fallback related events
std::vector<std::unordered_map<std::string, std::string>> events; //!< related events
};
}

namespace mtx {
namespace pushrules {

Expand All @@ -21,15 +30,19 @@ to_json(nlohmann::json &obj, const PushCondition &condition)
obj["pattern"] = condition.pattern;
if (!condition.is.empty())
obj["is"] = condition.is;
if (condition.rel_type != mtx::common::RelationType::Unsupported)
obj["rel_type"] = condition.rel_type;
}

void
from_json(const nlohmann::json &obj, PushCondition &condition)
{
condition.kind = obj["kind"].get<std::string>();
condition.key = obj.value("key", "");
condition.pattern = obj.value("pattern", "");
condition.is = obj.value("is", "");
condition.kind = obj["kind"].get<std::string>();
condition.key = obj.value("key", "");
condition.pattern = obj.value("pattern", "");
condition.is = obj.value("is", "");
condition.rel_type = obj.value("rel_type", mtx::common::RelationType::Unsupported);
condition.include_fallback = obj.value("include_fallback", false);
}

namespace actions {
Expand Down Expand Up @@ -180,11 +193,40 @@ struct PushRuleEvaluator::OptimizedRules
//! a pattern condition to match
struct PatternCondition
{
std::unique_ptr<re2::RE2> pattern; //< the pattern
std::string field; //< the field to match with pattern
std::unique_ptr<re2::RE2> pattern; //!< the pattern
std::string field; //!< the field to match with pattern

bool matches(const std::unordered_map<std::string, std::string> &ev) const
{
if (auto it = ev.find(field); it != ev.end()) {
if (pattern) {
if (field == "content.body") {
if (!re2::RE2::PartialMatch(it->second, *pattern))
return false;
} else {
if (!re2::RE2::FullMatch(it->second, *pattern))
return false;
}
}
} else {
return false;
}

return true;
}
};
// TODO(Nico): Sort by field for faster matching?
std::vector<PatternCondition> patterns; //< conditions that match on a field
std::vector<PatternCondition> patterns; //!< conditions that match on a field

//! a pattern condition to match on a related event
struct RelatedEventCondition
{
PatternCondition ev_match;
mtx::common::RelationType rel_type = mtx::common::RelationType::Unsupported;
bool include_fallbacks = false;
};
std::vector<RelatedEventCondition>
related_event_patterns; //!< conditions that match on fields of the related event.

//! a member count condition
struct MemberCountCondition
Expand Down Expand Up @@ -212,8 +254,10 @@ struct PushRuleEvaluator::OptimizedRules

std::vector<actions::Action> actions; //< the actions to apply on match

[[nodiscard]] bool matches(const std::unordered_map<std::string, std::string> &ev,
const PushRuleEvaluator::RoomContext &ctx) const
[[nodiscard]] bool matches(
const std::unordered_map<std::string, std::string> &ev,
const PushRuleEvaluator::RoomContext &ctx,
const std::map<mtx::common::RelationType, RelatedEvents> &relatedEventsFlat) const
{
for (const auto &cond : membercounts) {
if (![&cond, &ctx] {
Expand Down Expand Up @@ -249,19 +293,34 @@ struct PushRuleEvaluator::OptimizedRules
}

for (const auto &cond : patterns) {
if (auto it = ev.find(cond.field); it != ev.end()) {
if (cond.pattern) {
if (cond.field == "content.body") {
if (!re2::RE2::PartialMatch(it->second, *cond.pattern))
return false;
} else {
if (!re2::RE2::FullMatch(it->second, *cond.pattern))
return false;
if (!cond.matches(ev))
return false;
}

for (const auto &cond : related_event_patterns) {
bool matched = false;
for (const auto &[rel_type, rel_ev] : relatedEventsFlat) {
if (cond.rel_type == rel_type) {
for (const auto &e : rel_ev.events) {
if (cond.ev_match.field.empty() || !cond.ev_match.pattern ||
cond.ev_match.matches(e)) {
matched = true;
break;
}
}
if (cond.include_fallbacks) {
for (const auto &e : rel_ev.fallbacks) {
if (cond.ev_match.field.empty() || !cond.ev_match.pattern ||
cond.ev_match.matches(e)) {
matched = true;
break;
}
}
}
}
} else {
return false;
}
if (!matched)
return false;
}

if (check_displayname) {
Expand Down Expand Up @@ -325,6 +384,23 @@ PushRuleEvaluator::PushRuleEvaluator(const Ruleset &rules_)
c.pattern = construct_re_from_pattern(cond.pattern, cond.key);
if (c.pattern)
rule.patterns.push_back(std::move(c));
} else if (cond.kind == "im.nheko.msc3664.related_event_match") {
OptimizedRules::OptimizedRule::RelatedEventCondition c;

if (cond.rel_type != mtx::common::RelationType::Unsupported) {
c.rel_type = cond.rel_type;
c.include_fallbacks = cond.include_fallback;

if (!cond.key.empty() && !cond.pattern.empty()) {
c.ev_match.field = cond.key;
c.ev_match.pattern = construct_re_from_pattern(cond.pattern, cond.key);
}
rule.related_event_patterns.push_back(std::move(c));
} else {
mtx::utils::log::log()->info(
"Skipping rel_event_match rule with unknown rel_type.");
return false;
}
} else if (cond.kind == "contains_display_name") {
rule.check_displayname = true;
} else if (cond.kind == "room_member_count") {
Expand Down Expand Up @@ -479,19 +555,33 @@ flatten_event(const nlohmann::json &j)
}

std::vector<actions::Action>
PushRuleEvaluator::evaluate(const mtx::events::collections::TimelineEvent &event,
const RoomContext &ctx) const
PushRuleEvaluator::evaluate(
const mtx::events::collections::TimelineEvent &event,
const RoomContext &ctx,
const std::vector<std::pair<mtx::common::Relation, mtx::events::collections::TimelineEvent>>
&relatedEvents) const
{
auto event_json = nlohmann::json(event);
auto flat_event = flatten_event(event_json);

std::map<mtx::common::RelationType, RelatedEvents> relatedEventsFlat;
for (const auto &[rel, ev] : relatedEvents) {
if (rel.rel_type != mtx::common::RelationType::Unsupported) {
if (rel.is_fallback)
relatedEventsFlat[rel.rel_type].fallbacks.push_back(
flatten_event(nlohmann::json(ev)));
else
relatedEventsFlat[rel.rel_type].events.push_back(flatten_event(nlohmann::json(ev)));
}
}

for (const auto &rule : rules->override_) {
if (rule.matches(flat_event, ctx))
if (rule.matches(flat_event, ctx, relatedEventsFlat))
return rule.actions;
}

for (const auto &rule : rules->content) {
if (rule.matches(flat_event, ctx))
if (rule.matches(flat_event, ctx, relatedEventsFlat))
return rule.actions;
}

Expand All @@ -508,7 +598,7 @@ PushRuleEvaluator::evaluate(const mtx::events::collections::TimelineEvent &event
}

for (const auto &rule : rules->underride) {
if (rule.matches(flat_event, ctx))
if (rule.matches(flat_event, ctx, relatedEventsFlat))
return rule.actions;
}
return {};
Expand Down
Loading

0 comments on commit 7155cbb

Please sign in to comment.