From 4b7ec59103153439ac09fbc73a22b85242a45bcd Mon Sep 17 00:00:00 2001
From: wicky <tbc.dengwenqi@gmail.com>
Date: Sun, 5 Apr 2020 22:29:41 +0800
Subject: [PATCH 1/2] Relaxing type requirements for broadcast_like

---
 src/operator/tensor/broadcast_reduce_op_value.cc | 11 ++++++++++-
 tests/python/unittest/test_operator.py           |  9 +++++++++
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/src/operator/tensor/broadcast_reduce_op_value.cc b/src/operator/tensor/broadcast_reduce_op_value.cc
index 0a14a2008557..71be8f814f3b 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cc
+++ b/src/operator/tensor/broadcast_reduce_op_value.cc
@@ -138,7 +138,16 @@ NNVM_REGISTER_OP(broadcast_like)
     [](const NodeAttrs& attrs) {
       return std::vector<std::string>{"lhs", "rhs"};
     })
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs& attrs,
+                                             std::vector<int> *in_attrs,
+                                             std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2) << " in operator " << attrs.name;
+  std::vector<int> checked_in_attrs = { (*in_attrs)[0] };
+  bool ret = !type_is_none((*in_attrs)[1]) &&
+             ElemwiseType<1, 1>(attrs, &checked_in_attrs, out_attrs);
+  (*in_attrs)[0] = checked_in_attrs[0];
+  return ret;
+})
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::ObjectPtr& n,
     const std::vector<nnvm::NodeEntry>& ograds) {
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 6cbbc5dd0509..6b4c381c3344 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3119,6 +3119,15 @@ def test_reshape_like_different_types():
     z = mx.nd.reshape_like(x, y)
     assert_allclose(z.asnumpy(), [[0,0],[0,0],[0,0]])
 
+@with_seed()
+def test_broadcast_like_different_types():
+    x = mx.nd.zeros((2, 1))
+    y = mx.nd.ones((2, 2))
+
+    y = mx.nd.array(y).astype('int32')
+    z = mx.nd.broadcast_like(x, y)
+    assert_allclose(z.asnumpy(), [[0,0],[0,0]])
+
 @with_seed()
 def test_flip():
     for ndim in range(1, 6):

From feef46cb8c35685267c02c8be545be298a18e168 Mon Sep 17 00:00:00 2001
From: wicky <tbc.dengwenqi@gmail.com>
Date: Mon, 6 Apr 2020 21:08:06 +0800
Subject: [PATCH 2/2] enhance unit test

---
 tests/python/unittest/test_operator.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 6b4c381c3344..cb835339bee5 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3127,6 +3127,7 @@ def test_broadcast_like_different_types():
     y = mx.nd.array(y).astype('int32')
     z = mx.nd.broadcast_like(x, y)
     assert_allclose(z.asnumpy(), [[0,0],[0,0]])
+    assert x.dtype == z.dtype
 
 @with_seed()
 def test_flip():