static graph autogen code support for matmul op #54338
Conversation
Your PR was submitted successfully. Thank you for your contribution to the open-source project!
If the auto-generated composite-operator part is replaced with the original handwritten code, the unit tests pass. Both versions are shown below for comparison.
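The auto-generated version: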
class MatmulV2GradCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
public:
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
void Apply() override {
// get inputs
auto x = this->GetSingleForwardInput("X");
auto y = this->GetSingleForwardInput("Y");
auto grad_out = this->GetSingleForwardInput("grad_out");
auto grad_x_grad = this->GetOptionalSingleOutputGrad("grad_x");
auto grad_y_grad = this->GetOptionalSingleOutputGrad("grad_y");
// get attr
const bool transpose_x = this->Attr<bool>("trans_x");
const bool transpose_y = this->Attr<bool>("trans_y");
// get output
auto x_grad_t = this->GetSingleInputGrad("X");
auto y_grad_t = this->GetSingleInputGrad("Y");
auto grad_out_grad_t = this->GetSingleInputGrad("grad_out");
// get output ptr
auto x_grad = this->GetOutputPtr(&x_grad_t);
auto y_grad = this->GetOutputPtr(&y_grad_t);
auto grad_out_grad = this->GetOutputPtr(&grad_out_grad_t);
// get output original name
auto x_grad_name = this->GetOutputName(x_grad_t);
auto y_grad_name = this->GetOutputName(y_grad_t);
auto grad_out_grad_name = this->GetOutputName(grad_out_grad_t);
// call composite backward func
VLOG(6) << "Running matmul_double_grad composite func";
prim::matmul_double_grad<prim::DescTensor>(x, y, grad_out, grad_x_grad, grad_y_grad, transpose_x, transpose_y, x_grad, y_grad, grad_out_grad);
//recover output name
this->RecoverOutputName(x_grad_t, x_grad_name);
this->RecoverOutputName(y_grad_t, y_grad_name);
this->RecoverOutputName(grad_out_grad_t, grad_out_grad_name);
}
};
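The original handwritten version: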
class MatmulV2GradCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
public:
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
void Apply() override {
// get inputs
paddle::Tensor x = this->GetSingleForwardInput("X");
paddle::Tensor y = this->GetSingleForwardInput("Y");
paddle::Tensor dout =
this->GetSingleForwardInput(framework::GradVarName("Out"));
paddle::optional<paddle::Tensor> ddx =
this->GetOptionalSingleOutputGrad(framework::GradVarName("X"));
paddle::optional<paddle::Tensor> ddy =
this->GetOptionalSingleOutputGrad(framework::GradVarName("Y"));
// get attr
bool trans_x = this->Attr<bool>("trans_x");
bool trans_y = this->Attr<bool>("trans_y");
// get output
paddle::Tensor x_grad_t = this->GetSingleInputGrad("X");
paddle::Tensor y_grad_t = this->GetSingleInputGrad("Y");
paddle::Tensor grad_out_grad_t =
this->GetSingleInputGrad(framework::GradVarName("Out"));
// get output ptr
paddle::Tensor* x_grad = this->GetOutputPtr(&x_grad_t);
paddle::Tensor* y_grad = this->GetOutputPtr(&y_grad_t);
paddle::Tensor* grad_out_grad = this->GetOutputPtr(&grad_out_grad_t);
// get output original name
std::string x_grad_name = this->GetOutputName(x_grad_t);
std::string y_grad_name = this->GetOutputName(y_grad_t);
std::string grad_out_grad_name = this->GetOutputName(grad_out_grad_t);
VLOG(3) << "Running matmul_double_grad composite func";
// call composite backward func
prim::matmul_double_grad<prim::DescTensor>(
x, y, dout, ddx, ddy, trans_x, trans_y, x_grad, y_grad, grad_out_grad);
// recover output name
this->RecoverOutputName(x_grad_t, x_grad_name);
this->RecoverOutputName(y_grad_t, y_grad_name);
this->RecoverOutputName(grad_out_grad_t, grad_out_grad_name);
}
};
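The visible difference between the two makers is how gradient variable names are resolved: the generated code looks up the literal strings "grad_out", "grad_x", and "grad_y", while the handwritten code derives the static-graph names via framework::GradVarName from "Out", "X", and "Y". Below is a minimal sketch of what such a mapping does, assuming Paddle's conventional "@GRAD" suffix; the suffix and this simplified signature are illustrative assumptions, not code from this PR.

#include <iostream>
#include <string>

// Minimal sketch (assumption): gradient variables in the static graph are
// named by appending an "@GRAD" suffix to the forward variable name. The
// real framework::GradVarName lives in Paddle's framework and may differ
// in detail.
std::string GradVarName(const std::string& var_name) {
  return var_name + "@GRAD";
}

int main() {
  // The handwritten maker asks for GradVarName("Out") -> "Out@GRAD",
  // whereas the generated maker asked for the literal name "grad_out",
  // which would not match any variable in the block.
  std::cout << GradVarName("Out") << std::endl;  // prints "Out@GRAD"
  return 0;
}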
Not sure where the problem is.
The higher-order backward issue with composite operators has been reported to the relevant people and is expected to be resolved next week.
@heavyrain-lzy Please review again.
It looks fine now; please rerun the coverage CI.
@heavyrain-lzy I can't rerun coverage; I don't have permission.
Couldn't find it.
@heavyrain-lzy Please review this one again as well.
So far I have already ...
The coverage run has finished; the main issue is ...
LGTM
PR types
Others
PR changes
Others
Description
static graph autogen code support for matmul op