Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Refactor IR Printer: add object_path(from tvm) #183

Merged
merged 1 commit into from
Mar 10, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
313 changes: 313 additions & 0 deletions include/matxscript/ir/_base/object_path.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
// Copyright 2022 ByteDance Ltd. and/or its affiliates.
/*
* Acknowledgement: This file originates from TVM.
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file matx/ir/_base/object_path.h
* ObjectPath class that represents a path from a root object to one of its descendants
* via attribute access, array indexing etc.
*/

#pragma once

#include <string>

#include <matxscript/runtime/object.h>

#include <matxscript/ir/_base/optional_ref.h>
#include <matxscript/ir/_base/string_ref.h>

namespace matxscript {
namespace ir {

using runtime::Object;
using runtime::ObjectPtr;
using runtime::ObjectRef;

class ObjectPath;

/*!
* \brief Path to an object from some root object.
*
* Motivation:
*
* Same IR node object can be referenced in several different contexts inside a larger IR object.
* For example, a variable could be referenced in several statements within a block.
*
* This makes it impossible to use an object pointer to uniquely identify a "location" within
* the larger IR object for error reporting purposes. The ObjectPath class addresses this problem
* by serving as a unique "locator".
*/
class ObjectPathNode : public Object {
public:
/*! \brief Get the parent path */
Optional<ObjectPath> GetParent() const;
/*!
* \brief Get the length of the path.
*
* For example, the path returned by `ObjectPath::Root()` has length 1.
*/
int32_t Length() const;

/*!
* \brief Get a path prefix of the given length.
*
* Provided `length` must not exceed the `Length()` of this path.
*/
ObjectPath GetPrefix(int32_t length) const;

/*!
* \brief Check if this path is a prefix of another path.
*
* The prefix is not strict, i.e. a path is considered a prefix of itself.
*/
bool IsPrefixOf(const ObjectPath& other) const;

/*! \brief Check if two paths are equal. */
bool PathsEqual(const ObjectPath& other) const;

/*! \brief Extend this path with access to an object attribute. */
ObjectPath Attr(const char* attr_key) const;

/*! \brief Extend this path with access to an object attribute. */
ObjectPath Attr(Optional<StringRef> attr_key) const;

/*! \brief Extend this path with access to an array element. */
ObjectPath ArrayIndex(int32_t index) const;

/*! \brief Extend this path with access to a missing array element. */
ObjectPath MissingArrayElement(int32_t index) const;

/*! \brief Extend this path with access to a map value. */
ObjectPath MapValue(ObjectRef key) const;

/*! \brief Extend this path with access to a missing map entry. */
ObjectPath MissingMapEntry() const;

static constexpr const char* _type_key = "ObjectPath";
MATXSCRIPT_DECLARE_BASE_OBJECT_INFO(ObjectPathNode, Object);

protected:
explicit ObjectPathNode(const ObjectPathNode* parent);

friend class ObjectPath;
friend runtime::String GetObjectPathRepr(const ObjectPathNode* node);

const ObjectPathNode* ParentNode() const;

/*! Compares just the last node of the path, without comparing the whole path. */
virtual bool LastNodeEqual(const ObjectPathNode* other) const = 0;

virtual runtime::String LastNodeString() const = 0;

private:
Optional<ObjectRef> parent_;
int32_t length_;
};

class ObjectPath : public ObjectRef {
public:
/*! \brief Create a path that represents the root object itself. */
static ObjectPath Root();

MATXSCRIPT_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPath, ObjectRef, ObjectPathNode);
};

//-------------------------------------------------------------------------
//----- Concrete object path nodes ------------------------------------
//-------------------------------------------------------------------------

// ----- Root -----

class RootPathNode final : public ObjectPathNode {
public:
explicit RootPathNode();

static constexpr const char* _type_key = "RootPath";
MATXSCRIPT_DECLARE_FINAL_OBJECT_INFO(RootPathNode, ObjectPathNode);

protected:
bool LastNodeEqual(const ObjectPathNode* other) const final;
runtime::String LastNodeString() const final;
};

class RootPath : public ObjectPath {
public:
MATXSCRIPT_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RootPath, ObjectPath, RootPathNode);
};

// ----- Attribute access -----

class AttributeAccessPathNode final : public ObjectPathNode {
public:
/*! \brief Name of the attribute being accessed. Must be a static string. */
StringRef attr_key;

explicit AttributeAccessPathNode(const ObjectPathNode* parent, StringRef attr_key);

static constexpr const char* _type_key = "AttributeAccessPath";
MATXSCRIPT_DECLARE_FINAL_OBJECT_INFO(AttributeAccessPathNode, ObjectPathNode);

protected:
bool LastNodeEqual(const ObjectPathNode* other) const final;
runtime::String LastNodeString() const final;
};

class AttributeAccessPath : public ObjectPath {
public:
MATXSCRIPT_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AttributeAccessPath,
ObjectPath,
AttributeAccessPathNode);
};

// ----- Unknown attribute access -----

class UnknownAttributeAccessPathNode final : public ObjectPathNode {
public:
explicit UnknownAttributeAccessPathNode(const ObjectPathNode* parent);

static constexpr const char* _type_key = "UnknownAttributeAccessPath";
MATXSCRIPT_DECLARE_FINAL_OBJECT_INFO(UnknownAttributeAccessPathNode, ObjectPathNode);

protected:
bool LastNodeEqual(const ObjectPathNode* other) const final;
runtime::String LastNodeString() const final;
};

class UnknownAttributeAccessPath : public ObjectPath {
public:
MATXSCRIPT_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(UnknownAttributeAccessPath,
ObjectPath,
UnknownAttributeAccessPathNode);
};

// ----- Array element access by index -----

class ArrayIndexPathNode : public ObjectPathNode {
public:
/*! \brief Index of the array element that is being accessed. */
int32_t index;

explicit ArrayIndexPathNode(const ObjectPathNode* parent, int32_t index);

static constexpr const char* _type_key = "ArrayIndexPath";
MATXSCRIPT_DECLARE_FINAL_OBJECT_INFO(ArrayIndexPathNode, ObjectPathNode);

protected:
bool LastNodeEqual(const ObjectPathNode* other) const final;
runtime::String LastNodeString() const final;
};

class ArrayIndexPath : public ObjectPath {
public:
MATXSCRIPT_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ArrayIndexPath, ObjectPath, ArrayIndexPathNode);
};

// ----- Missing array element -----

class MissingArrayElementPathNode : public ObjectPathNode {
public:
/*! \brief Index of the array element that is missing. */
int32_t index;

explicit MissingArrayElementPathNode(const ObjectPathNode* parent, int32_t index);

static constexpr const char* _type_key = "MissingArrayElementPath";
MATXSCRIPT_DECLARE_FINAL_OBJECT_INFO(MissingArrayElementPathNode, ObjectPathNode);

protected:
bool LastNodeEqual(const ObjectPathNode* other) const final;
runtime::String LastNodeString() const final;
};

class MissingArrayElementPath : public ObjectPath {
public:
MATXSCRIPT_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MissingArrayElementPath,
ObjectPath,
MissingArrayElementPathNode);
};

// ----- Map value -----

class MapValuePathNode : public ObjectPathNode {
public:
/*! \brief Key of the map entry that is being accessed */
ObjectRef key;

explicit MapValuePathNode(const ObjectPathNode* parent, ObjectRef key);

static constexpr const char* _type_key = "MapValuePath";
MATXSCRIPT_DECLARE_FINAL_OBJECT_INFO(MapValuePathNode, ObjectPathNode);

protected:
bool LastNodeEqual(const ObjectPathNode* other) const final;
runtime::String LastNodeString() const final;
};

class MapValuePath : public ObjectPath {
public:
MATXSCRIPT_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MapValuePath, ObjectPath, MapValuePathNode);
};

// ----- Missing map entry -----

class MissingMapEntryPathNode : public ObjectPathNode {
public:
explicit MissingMapEntryPathNode(const ObjectPathNode* parent);

static constexpr const char* _type_key = "MissingMapEntryPath";
MATXSCRIPT_DECLARE_FINAL_OBJECT_INFO(MissingMapEntryPathNode, ObjectPathNode);

protected:
bool LastNodeEqual(const ObjectPathNode* other) const final;
runtime::String LastNodeString() const final;
};

class MissingMapEntryPath : public ObjectPath {
public:
MATXSCRIPT_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MissingMapEntryPath,
ObjectPath,
MissingMapEntryPathNode);
};

/*!
* \brief Pair of `ObjectPath`s, one for each object being tested for structural equality.
*/
class ObjectPathPairNode : public Object {
public:
ObjectPath lhs_path;
ObjectPath rhs_path;

ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path);

static constexpr const char* _type_key = "ObjectPathPair";
MATXSCRIPT_DECLARE_FINAL_OBJECT_INFO(ObjectPathPairNode, Object);
};

class ObjectPathPair : public ObjectRef {
public:
ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path);

MATXSCRIPT_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPathPair, ObjectRef, ObjectPathPairNode);
};

} // namespace ir
} // namespace matxscript
22 changes: 22 additions & 0 deletions python/matx/ir/_ffi_node_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2022 ByteDance Ltd. and/or its affiliates.
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for matx.ir"""
from .. import _ffi

_ffi._init_api("node", __name__)
146 changes: 146 additions & 0 deletions python/matx/ir/object_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright 2022 ByteDance Ltd. and/or its affiliates.
#
# Acknowledgement: This file originates from TVM.
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
ObjectPath class that represents a path from a root object to one of its descendants
via attribute access, array indexing etc.
"""

from .. import _ffi
from ..runtime import Object
from . import _ffi_node_api

__all__ = (
"ObjectPath",
"RootPath",
"AttributeAccessPath",
"UnknownAttributeAccessPath",
"ArrayIndexPath",
"MissingArrayElementPath",
"MapValuePath",
"MissingMapEntryPath",
"ObjectPathPair",
)


@_ffi.register_object("ObjectPath")
class ObjectPath(Object):
"""
Path to an object from some root object.
"""

def __init__(self) -> None:
super().__init__()
raise ValueError(
"ObjectPath can't be initialized directly. "
"Use ObjectPath.root() to create a path to the root object"
)

@staticmethod
def root() -> "ObjectPath":
return _ffi_node_api.ObjectPathRoot()

def __eq__(self, other):
return _ffi_node_api.ObjectPathEqual(self, other)

def __ne__(self, other):
return not _ffi_node_api.ObjectPathEqual(self, other)

@property
def parent(self) -> "ObjectPath":
return _ffi_node_api.ObjectPathGetParent(self)

def __len__(self) -> int:
return _ffi_node_api.ObjectPathLength(self)

def get_prefix(self, length) -> "ObjectPath":
return _ffi_node_api.ObjectPathGetPrefix(self, length)

def is_prefix_of(self, other) -> "ObjectPath":
return _ffi_node_api.ObjectPathIsPrefixOf(self, other)

def attr(self, attr_key) -> "ObjectPath":
return _ffi_node_api.ObjectPathAttr(self, attr_key)

def array_index(self, index) -> "ObjectPath":
return _ffi_node_api.ObjectPathArrayIndex(self, index)

def missing_array_element(self, index) -> "ObjectPath":
return _ffi_node_api.ObjectPathMissingArrayElement(self, index)

def map_value(self, key) -> "ObjectPath":
from ._converter import convert
return _ffi_node_api.ObjectPathMapValue(self, convert(key))

def missing_map_entry(self) -> "ObjectPath":
return _ffi_node_api.ObjectPathMissingMapEntry(self)

__hash__ = Object.__hash__


@_ffi.register_object("RootPath")
class RootPath(ObjectPath):
pass


@_ffi.register_object("AttributeAccessPath")
class AttributeAccessPath(ObjectPath):
pass


@_ffi.register_object("UnknownAttributeAccessPath")
class UnknownAttributeAccessPath(ObjectPath):
pass


@_ffi.register_object("ArrayIndexPath")
class ArrayIndexPath(ObjectPath):
pass


@_ffi.register_object("MissingArrayElementPath")
class MissingArrayElementPath(ObjectPath):
pass


@_ffi.register_object("MapValuePath")
class MapValuePath(ObjectPath):
pass


@_ffi.register_object("MissingMapEntryPath")
class MissingMapEntryPath(ObjectPath):
pass


@_ffi.register_object("ObjectPathPair")
class ObjectPathPair(Object):
"""
Pair of ObjectPaths, one for each object being tested for structural equality.
"""

@property
def lhs_path(self) -> ObjectPath:
return _ffi_node_api.ObjectPathPairLhsPath(self)

@property
def rhs_path(self) -> ObjectPath:
return _ffi_node_api.ObjectPathPairRhsPath(self)
383 changes: 383 additions & 0 deletions src/ir/_base/object_path.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,383 @@
// Copyright 2022 ByteDance Ltd. and/or its affiliates.
/*
* Acknowledgement: This file originates from TVM.
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <matxscript/ir/_base/object_path.h>

#include <algorithm>
#include <cstring>

#include <matxscript/ir/_base/repr_printer.h>
#include <matxscript/runtime/memory.h>
#include <matxscript/runtime/registry.h>
#include "matxscript/ir/_base/reflection.h"

using namespace matxscript::runtime;

namespace matxscript {
namespace ir {

// ============== ObjectPathNode ==============

ObjectPathNode::ObjectPathNode(const ObjectPathNode* parent)
: parent_(GetRef<ObjectRef>(parent)), length_(parent == nullptr ? 1 : parent->length_ + 1) {
}

// --- GetParent ---

Optional<ObjectPath> ObjectPathNode::GetParent() const {
if (parent_ == nullptr) {
return NullOpt;
} else {
return Downcast<ObjectPath>(parent_);
}
}

MATXSCRIPT_REGISTER_GLOBAL("node.ObjectPathGetParent").set_body_typed([](const ObjectPath& self) {
return self->GetParent();
});

// --- Length ---

int32_t ObjectPathNode::Length() const {
return length_;
}

MATXSCRIPT_REGISTER_GLOBAL("node.ObjectPathLength").set_body_typed([](const ObjectPath& self) {
return self->Length();
});

// --- GetPrefix ---

ObjectPath ObjectPathNode::GetPrefix(int32_t length) const {
MXCHECK_GE(length, 1) << "IndexError: Prefix length must be at least 1";
MXCHECK_LE(length, Length())
<< "IndexError: Attempted to get a prefix longer than the path itself";

const ObjectPathNode* node = this;
int32_t suffix_len = Length() - length;
for (int32_t i = 0; i < suffix_len; ++i) {
node = node->ParentNode();
}

return GetRef<ObjectPath>(node);
}

MATXSCRIPT_REGISTER_GLOBAL("node.ObjectPathGetPrefix")
.set_body_typed([](const ObjectPath& self, int32_t length) { return self->GetPrefix(length); });

// --- IsPrefixOf ---

bool ObjectPathNode::IsPrefixOf(const ObjectPath& other) const {
int32_t this_len = Length();
if (this_len > other->Length()) {
return false;
}
return this->PathsEqual(other->GetPrefix(this_len));
}

MATXSCRIPT_REGISTER_GLOBAL("node.ObjectPathIsPrefixOf")
.set_body_typed([](const ObjectPath& self, const ObjectPath& other) {
return self->IsPrefixOf(other);
});

// --- Attr ---

ObjectPath ObjectPathNode::Attr(const char* attr_key) const {
if (attr_key != nullptr) {
return ObjectPath(make_object<AttributeAccessPathNode>(this, attr_key));
} else {
return ObjectPath(make_object<UnknownAttributeAccessPathNode>(this));
}
}

ObjectPath ObjectPathNode::Attr(Optional<StringRef> attr_key) const {
if (attr_key.defined()) {
return ObjectPath(make_object<AttributeAccessPathNode>(this, attr_key.value()));
} else {
return ObjectPath(make_object<UnknownAttributeAccessPathNode>(this));
}
}

MATXSCRIPT_REGISTER_GLOBAL("node.ObjectPathAttr")
.set_body_typed([](const ObjectPath& object_path, const Optional<StringRef>& attr_key) {
return object_path->Attr(attr_key);
});

// --- ArrayIndex ---

ObjectPath ObjectPathNode::ArrayIndex(int32_t index) const {
return ObjectPath(make_object<ArrayIndexPathNode>(this, index));
}

MATXSCRIPT_REGISTER_GLOBAL("node.ObjectPathArrayIndex")
.set_body_typed([](const ObjectPath& self, int32_t index) { return self->ArrayIndex(index); });

// --- MissingArrayElement ---

ObjectPath ObjectPathNode::MissingArrayElement(int32_t index) const {
return ObjectPath(make_object<MissingArrayElementPathNode>(this, index));
}

MATXSCRIPT_REGISTER_GLOBAL("node.ObjectPathMissingArrayElement")
.set_body_typed([](const ObjectPath& self, int32_t index) {
return self->MissingArrayElement(index);
});

// --- MapValue ---

ObjectPath ObjectPathNode::MapValue(ObjectRef key) const {
return ObjectPath(make_object<MapValuePathNode>(this, std::move(key)));
}

MATXSCRIPT_REGISTER_GLOBAL("node.ObjectPathMapValue")
.set_body_typed([](const ObjectPath& self, const ObjectRef& key) {
return self->MapValue(key);
});

// --- MissingMapEntry ---

ObjectPath ObjectPathNode::MissingMapEntry() const {
return ObjectPath(make_object<MissingMapEntryPathNode>(this));
}

MATXSCRIPT_REGISTER_GLOBAL("node.ObjectPathMissingMapEntry")
.set_body_typed([](const ObjectPath& self) { return self->MissingMapEntry(); });

// --- PathsEqual ----

bool ObjectPathNode::PathsEqual(const ObjectPath& other) const {
if (!other.defined() || Length() != other->Length()) {
return false;
}

const ObjectPathNode* lhs = this;
const ObjectPathNode* rhs = static_cast<const ObjectPathNode*>(other.get());

while (lhs != nullptr && rhs != nullptr) {
if (lhs->type_index() != rhs->type_index()) {
return false;
}
if (!lhs->LastNodeEqual(rhs)) {
return false;
}
lhs = lhs->ParentNode();
rhs = rhs->ParentNode();
}

return lhs == nullptr && rhs == nullptr;
}

MATXSCRIPT_REGISTER_GLOBAL("node.ObjectPathEqual")
.set_body_typed([](const ObjectPath& self, const ObjectPath& other) {
return self->PathsEqual(other);
});

// --- Repr ---

runtime::String GetObjectPathRepr(const ObjectPathNode* node) {
runtime::String ret;
while (node != nullptr) {
runtime::String node_str = node->LastNodeString();
ret.append(node_str.rbegin(), node_str.rend());
node = static_cast<const ObjectPathNode*>(node->GetParent().get());
}
std::reverse(ret.begin(), ret.end());
return ret;
}

static void PrintObjectPathRepr(const ObjectRef& node, ReprPrinter* p) {
p->stream << GetObjectPathRepr(static_cast<const ObjectPathNode*>(node.get()));
}

MATXSCRIPT_REGISTER_OBJECT_TYPE(ObjectPathNode);
MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<ObjectPathNode>(PrintObjectPathRepr);

// --- Private/protected methods ---

const ObjectPathNode* ObjectPathNode::ParentNode() const {
return static_cast<const ObjectPathNode*>(parent_.get());
}

// ============== ObjectPath ==============

/* static */ ObjectPath ObjectPath::Root() {
return ObjectPath(make_object<RootPathNode>());
}

MATXSCRIPT_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed(ObjectPath::Root);

// ============== Individual path classes ==============

// ----- Root -----

RootPathNode::RootPathNode() : ObjectPathNode(nullptr) {
}

bool RootPathNode::LastNodeEqual(const ObjectPathNode* other) const {
return true;
}

runtime::String RootPathNode::LastNodeString() const {
return "<root>";
}

MATXSCRIPT_REGISTER_OBJECT_TYPE(RootPathNode);
MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<RootPathNode>(PrintObjectPathRepr);

// ----- AttributeAccess -----

AttributeAccessPathNode::AttributeAccessPathNode(const ObjectPathNode* parent, StringRef attr_key)
: ObjectPathNode(parent), attr_key(std::move(attr_key)) {
}

bool AttributeAccessPathNode::LastNodeEqual(const ObjectPathNode* other) const {
const auto* otherAttrAccess = static_cast<const AttributeAccessPathNode*>(other);
return attr_key == otherAttrAccess->attr_key;
}

runtime::String AttributeAccessPathNode::LastNodeString() const {
return "." + attr_key;
}

MATXSCRIPT_REGISTER_OBJECT_TYPE(AttributeAccessPathNode);
MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AttributeAccessPathNode>(PrintObjectPathRepr);

// ----- UnknownAttributeAccess -----

UnknownAttributeAccessPathNode::UnknownAttributeAccessPathNode(const ObjectPathNode* parent)
: ObjectPathNode(parent) {
}

bool UnknownAttributeAccessPathNode::LastNodeEqual(const ObjectPathNode* other) const {
// Consider any two unknown attribute accesses unequal
return false;
}

runtime::String UnknownAttributeAccessPathNode::LastNodeString() const {
return ".<unknown attribute>";
}

MATXSCRIPT_REGISTER_OBJECT_TYPE(UnknownAttributeAccessPathNode);
MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<UnknownAttributeAccessPathNode>(PrintObjectPathRepr);

// ----- ArrayIndexPath -----

ArrayIndexPathNode::ArrayIndexPathNode(const ObjectPathNode* parent, int32_t index)
: ObjectPathNode(parent), index(index) {
}

bool ArrayIndexPathNode::LastNodeEqual(const ObjectPathNode* other) const {
const auto* otherArrayIndex = static_cast<const ArrayIndexPathNode*>(other);
return index == otherArrayIndex->index;
}

runtime::String ArrayIndexPathNode::LastNodeString() const {
return "[" + std::to_string(index) + "]";
}

MATXSCRIPT_REGISTER_OBJECT_TYPE(ArrayIndexPathNode);
MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ArrayIndexPathNode>(PrintObjectPathRepr);

// ----- MissingArrayElement -----

MissingArrayElementPathNode::MissingArrayElementPathNode(const ObjectPathNode* parent,
int32_t index)
: ObjectPathNode(parent), index(index) {
}

bool MissingArrayElementPathNode::LastNodeEqual(const ObjectPathNode* other) const {
const auto* otherMissingElement = static_cast<const MissingArrayElementPathNode*>(other);
return index == otherMissingElement->index;
}

runtime::String MissingArrayElementPathNode::LastNodeString() const {
return "[<missing element #" + std::to_string(index) + ">]";
}

MATXSCRIPT_REGISTER_OBJECT_TYPE(MissingArrayElementPathNode);
MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MissingArrayElementPathNode>(PrintObjectPathRepr);

// ----- MapValue -----

MapValuePathNode::MapValuePathNode(const ObjectPathNode* parent, ObjectRef key)
: ObjectPathNode(parent), key(std::move(key)) {
}

bool MapValuePathNode::LastNodeEqual(const ObjectPathNode* other) const {
const auto* otherMapValue = static_cast<const MapValuePathNode*>(other);
return ObjectEqual()(key, otherMapValue->key);
}

runtime::String MapValuePathNode::LastNodeString() const {
std::ostringstream s;
s << "[" << key << "]";
return s.str();
}

MATXSCRIPT_REGISTER_OBJECT_TYPE(MapValuePathNode);
MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MapValuePathNode>(PrintObjectPathRepr);

// ----- MissingMapEntry -----

MissingMapEntryPathNode::MissingMapEntryPathNode(const ObjectPathNode* parent)
: ObjectPathNode(parent) {
}

bool MissingMapEntryPathNode::LastNodeEqual(const ObjectPathNode* other) const {
return true;
}

runtime::String MissingMapEntryPathNode::LastNodeString() const {
return "[<missing entry>]";
}

MATXSCRIPT_REGISTER_OBJECT_TYPE(MissingMapEntryPathNode);
MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MissingMapEntryPathNode>(PrintObjectPathRepr);

MATXSCRIPT_REGISTER_OBJECT_TYPE(ObjectPathPairNode);

MATXSCRIPT_REGISTER_GLOBAL("node.ObjectPathPairLhsPath")
.set_body_typed([](const ObjectPathPair& object_path_pair) {
return object_path_pair->lhs_path;
});

MATXSCRIPT_REGISTER_GLOBAL("node.ObjectPathPairRhsPath")
.set_body_typed([](const ObjectPathPair& object_path_pair) {
return object_path_pair->rhs_path;
});

ObjectPathPairNode::ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path)
: lhs_path(std::move(lhs_path)), rhs_path(std::move(rhs_path)) {
}

ObjectPathPair::ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path) {
data_ = make_object<ObjectPathPairNode>(std::move(lhs_path), std::move(rhs_path));
}

} // namespace ir
} // namespace matxscript
147 changes: 147 additions & 0 deletions test/ir/test_object_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright 2022 ByteDance Ltd. and/or its affiliates.
#
# Acknowledgement: This file originates from TVM.
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import unittest
import matx

from matx.ir import object_path
from matx.ir.object_path import ObjectPath


class TestIRObjectPath(unittest.TestCase):

def setUp(self) -> None:
pass

def test_root_path(self):
root = ObjectPath.root()
assert isinstance(root, object_path.RootPath)
assert str(root) == "<root>"
assert len(root) == 1
assert root == ObjectPath.root()
assert root.parent is None

def test_path_attr(self):
path = ObjectPath.root().attr("foo")
assert isinstance(path, object_path.AttributeAccessPath)
assert str(path) == "<root>.foo"
assert len(path) == 2
assert path.parent == ObjectPath.root()

def test_path_attr_unknown(self):
path = ObjectPath.root().attr(None)
assert isinstance(path, object_path.UnknownAttributeAccessPath)
assert str(path) == "<root>.<unknown attribute>"
assert len(path) == 2
assert path.parent == ObjectPath.root()

def test_path_array_index(self):
path = ObjectPath.root().array_index(2)
assert isinstance(path, object_path.ArrayIndexPath)
assert str(path) == "<root>[2]"
assert len(path) == 2
assert path.parent == ObjectPath.root()

def test_path_missing_array_element(self):
path = ObjectPath.root().missing_array_element(2)
assert isinstance(path, object_path.MissingArrayElementPath)
assert str(path) == "<root>[<missing element #2>]"
assert len(path) == 2
assert path.parent == ObjectPath.root()

def test_path_map_value(self):
path = ObjectPath.root().map_value("foo")
assert isinstance(path, object_path.MapValuePath)
assert str(path) == '<root>["foo"]'
assert len(path) == 2
assert path.parent == ObjectPath.root()

def test_path_missing_map_entry(self):
path = ObjectPath.root().missing_map_entry()
assert isinstance(path, object_path.MissingMapEntryPath)
assert str(path) == "<root>[<missing entry>]"
assert len(path) == 2
assert path.parent == ObjectPath.root()

def test_path_is_prefix_of(self):
parametrizes = [
(ObjectPath.root(), ObjectPath.root(), True),
(ObjectPath.root(), ObjectPath.root().attr("foo"), True),
(ObjectPath.root().attr("foo"), ObjectPath.root(), False),
(ObjectPath.root().attr("foo"), ObjectPath.root().attr("foo"), True),
(ObjectPath.root().attr("bar"), ObjectPath.root().attr("foo"), False),
(ObjectPath.root().attr("foo"), ObjectPath.root().attr("foo").array_index(2), True),
(ObjectPath.root().attr("foo").array_index(2), ObjectPath.root().attr("foo"), False),
(ObjectPath.root().attr("foo"), ObjectPath.root().attr("bar").array_index(2), False),
]

def test_func(a, b, expected):
assert a.is_prefix_of(b) == expected

for a, b, expected in parametrizes:
test_func(a, b, expected)

def test_path_equal(self):
paths_for_equality_test = [
ObjectPath.root(),
ObjectPath.root().attr("foo"),
ObjectPath.root().attr("bar"),
ObjectPath.root().array_index(3),
ObjectPath.root().array_index(4),
ObjectPath.root().missing_array_element(3),
ObjectPath.root().missing_array_element(4),
ObjectPath.root().map_value("foo"),
ObjectPath.root().map_value("bar"),
ObjectPath.root().missing_map_entry(),
ObjectPath.root().attr("foo").missing_map_entry(),
]

def test_path_equal_impl(a_idx, a_path, b_idx, b_path):
expected = a_idx == b_idx
result = a_path == b_path
assert result == expected

for idx, path in enumerate(paths_for_equality_test):
test_path_equal_impl(idx, path, idx, path)

def test_path_get_prefix(self):
p1 = ObjectPath.root()
p2 = p1.attr("foo")
p3 = p2.array_index(5)

assert p3.parent == p2
assert p2.parent == p1
assert p1.parent is None

assert p2.get_prefix(1) == p1

assert p3.get_prefix(1) == p1
assert p3.get_prefix(2) == p2
assert p3.get_prefix(3) == p3

with self.assertRaises(IndexError) as e:
p3.get_prefix(0)
assert "Prefix length must be at least 1" in str(e.exception)

with self.assertRaises(IndexError) as e:
p3.get_prefix(4)
assert "Attempted to get a prefix longer than the path itself" in str(e.exception)