Skip to content

Commit

Permalink
use pybind operators header to define operator overloads
Browse files Browse the repository at this point in the history
  • Loading branch information
shivadbhavsar committed Jan 23, 2025
1 parent 3f8e27e commit cac330b
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/py/migraphx_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <migraphx/program.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp>
Expand Down Expand Up @@ -419,9 +420,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("op", [](migraphx::instruction_ref i) { return i->get_operator(); })
.def("inputs", [](migraphx::instruction_ref i) { return i->inputs(); })
.def("name", [](migraphx::instruction_ref i) { return i->name(); })
.def("__hash__", std::hash<migraphx::instruction_ref>{})
.def("__eq__", std::equal_to<migraphx::instruction_ref>{})
.def("__eq__", std::equal_to<py::object>{});
.def(py::hash(py::self))
.def(py::self == py::self)
.def(py::self != py::self);

py::class_<migraphx::module, std::unique_ptr<migraphx::module, py::nodelete>>(m, "module")
.def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; })
Expand Down

0 comments on commit cac330b

Please sign in to comment.