Skip to content

Commit

Permalink
fix xpu
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqiu committed Jan 25, 2022
1 parent 17933cb commit 2e61e81
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1153,6 +1153,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,

kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(exe_ctx))));
dev_ctx = pool.Get(kernel_type_->place_);

pt_kernel_name = pt_kernel_signature_->name;
pt_kernel_key = TransOpKernelTypeToPtenKernelKey(*kernel_type_.get());
Expand Down Expand Up @@ -1205,6 +1206,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
if (!run_pten_kernel_) {
if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
ChooseKernel(exe_ctx);
dev_ctx = pool.Get(kernel_type_->place_);
}
}

Expand All @@ -1223,10 +1225,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
const Scope& exec_scope =
(transfer_scope == nullptr ? scope : *transfer_scope);

if (!run_pten_kernel_ && !(kernel_type_->place_ == dev_ctx->GetPlace())) {
dev_ctx = pool.Get(kernel_type_->place_);
}

if (!all_kernels_must_compute_runtime_shape_) {
platform::RecordEvent record_event("infer_shape",
platform::EventRole::kInnerOp);
Expand Down

0 comments on commit 2e61e81

Please sign in to comment.