Skip to content

Commit

Permalink
simplify log of range
Browse files Browse the repository at this point in the history
  • Loading branch information
BiynXu committed Nov 1, 2023
1 parent 4d5fc47 commit b20f611
Showing 1 changed file with 52 additions and 69 deletions.
121 changes: 52 additions & 69 deletions paddle/cinn/hlir/framework/group_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,16 @@ void GroupScheduler::BindCudaAxis() {
<< ir_sch_->GetModule().GetExprs().front();
}

struct Range {
int min;
int max;
};

std::ostream& operator<<(std::ostream& os, const Range& x) {
os << "(" << x.min << ", " << x.max << ")";
return os;
}

void GroupScheduler::AllocateStorage() {
if (target_.arch != Target::Arch::NVGPU) return;
VLOG(5) << "[Start AllocateStorage] func body: "
Expand Down Expand Up @@ -688,11 +698,6 @@ void GroupScheduler::AllocateStorage() {
kCudaThreadAndSerial,
};

struct Range {
int min;
int max;
};

// function to calculate the range of the specified CUDA axis in a indice
// expression
auto CalculateRange = [&for_map](ir::Expr indice_value,
Expand Down Expand Up @@ -862,44 +867,32 @@ void GroupScheduler::AllocateStorage() {
store_indice_value, ir::ForType::GPUThread, store_block_name);
auto load_thread_coefficient_and_range = GetCoefficientAndRange(
load_indice_value, ir::ForType::GPUThread, load_block_name);
VLOG(6) << "store_indice_value: " << store_indice_value;
VLOG(6) << "load_indice_value: " << load_indice_value;
VLOG(6) << "store_block_name: " << store_block_name;
VLOG(6) << "load_block_name: " << load_block_name;
VLOG(6) << "store_thread_overall_range = ("
<< store_thread_overall_range.min << ", "
<< store_thread_overall_range.max << ")";
VLOG(6) << "load_thread_overall_range = (" << load_thread_overall_range.min
<< ", " << load_thread_overall_range.max << ")";
VLOG(6) << "store_serial_overall_range = ("
<< store_serial_overall_range.min << ", "
<< store_serial_overall_range.max << ")";
VLOG(6) << "load_serial_overall_range = (" << load_serial_overall_range.min
<< ", " << load_serial_overall_range.max << ")";
VLOG(6) << "store_block_name: " << store_block_name
<< ", load_block_name: " << load_block_name;
VLOG(6) << "store_indice_value: " << store_indice_value
<< ", load_indice_value: " << load_indice_value;
VLOG(6) << "store_thread_overall_range = " << store_thread_overall_range;
VLOG(6) << "load_thread_overall_range = " << load_thread_overall_range;
VLOG(6) << "store_serial_overall_range = " << store_serial_overall_range;
VLOG(6) << "load_serial_overall_range = " << load_serial_overall_range;
VLOG(6) << "store_thread_coefficient_and_range[0] = <"
<< store_thread_coefficient_and_range[0].first << ", ("
<< store_thread_coefficient_and_range[0].second.min << ", "
<< store_thread_coefficient_and_range[0].second.max << ")>";
<< store_thread_coefficient_and_range[0].first << ", "
<< store_thread_coefficient_and_range[0].second << ">";
VLOG(6) << "load_thread_coefficient_and_range[0] = <"
<< load_thread_coefficient_and_range[0].first << ", ("
<< load_thread_coefficient_and_range[0].second.min << ", "
<< load_thread_coefficient_and_range[0].second.max << ")>";
<< load_thread_coefficient_and_range[0].first << ", "
<< load_thread_coefficient_and_range[0].second << ">";
VLOG(6) << "store_thread_coefficient_and_range[1] = <"
<< store_thread_coefficient_and_range[1].first << ", ("
<< store_thread_coefficient_and_range[1].second.min << ", "
<< store_thread_coefficient_and_range[1].second.max << ")>";
<< store_thread_coefficient_and_range[1].first << ", "
<< store_thread_coefficient_and_range[1].second << ">";
VLOG(6) << "load_thread_coefficient_and_range[1] = <"
<< load_thread_coefficient_and_range[1].first << ", ("
<< load_thread_coefficient_and_range[1].second.min << ", "
<< load_thread_coefficient_and_range[1].second.max << ")>";
<< load_thread_coefficient_and_range[1].first << ", "
<< load_thread_coefficient_and_range[1].second << ">";
VLOG(6) << "store_thread_coefficient_and_range[2] = <"
<< store_thread_coefficient_and_range[2].first << ", ("
<< store_thread_coefficient_and_range[2].second.min << ", "
<< store_thread_coefficient_and_range[2].second.max << ")>";
<< store_thread_coefficient_and_range[2].first << ", "
<< store_thread_coefficient_and_range[2].second << ">";
VLOG(6) << "load_thread_coefficient_and_range[2] = <"
<< load_thread_coefficient_and_range[2].first << ", ("
<< load_thread_coefficient_and_range[2].second.min << ", "
<< load_thread_coefficient_and_range[2].second.max << ")>";
<< load_thread_coefficient_and_range[2].first << ", "
<< load_thread_coefficient_and_range[2].second << ">";
return !(store_thread_overall_range.min <= load_thread_overall_range.min &&
store_thread_overall_range.max >= load_thread_overall_range.max &&
store_serial_overall_range.min <= load_serial_overall_range.min &&
Expand Down Expand Up @@ -946,44 +939,34 @@ void GroupScheduler::AllocateStorage() {
store_indice_value, ir::ForType::GPUBlock, store_block_name);
auto load_block_coefficient_and_range = GetCoefficientAndRange(
load_indice_value, ir::ForType::GPUBlock, load_block_name);
VLOG(6) << "store_indice_value: " << store_indice_value;
VLOG(6) << "load_indice_value: " << load_indice_value;
VLOG(6) << "store_block_name: " << store_block_name;
VLOG(6) << "load_block_name: " << load_block_name;
VLOG(6) << "store_block_overall_range = (" << store_block_overall_range.min
<< ", " << store_block_overall_range.max << ")";
VLOG(6) << "load_block_overall_range = (" << load_block_overall_range.min
<< ", " << load_block_overall_range.max << ")";
VLOG(6) << "store_thread_and_serial_overall_range = ("
<< store_thread_and_serial_overall_range.min << ", "
<< store_thread_and_serial_overall_range.max << ")";
VLOG(6) << "load_thread_and_serial_overall_range = ("
<< load_thread_and_serial_overall_range.min << ", "
<< load_thread_and_serial_overall_range.max << ")";
VLOG(6) << "store_block_name: " << store_block_name
<< ", load_block_name: " << load_block_name;
VLOG(6) << "store_indice_value: " << store_indice_value
<< ", load_indice_value: " << load_indice_value;
VLOG(6) << "store_block_overall_range = " << store_block_overall_range;
VLOG(6) << "load_block_overall_range = " << load_block_overall_range;
VLOG(6) << "store_thread_and_serial_overall_range = "
<< store_thread_and_serial_overall_range;
VLOG(6) << "load_thread_and_serial_overall_range = "
<< load_thread_and_serial_overall_range;
VLOG(6) << "store_block_coefficient_and_range[0] = <"
<< store_block_coefficient_and_range[0].first << ", ("
<< store_block_coefficient_and_range[0].second.min << ", "
<< store_block_coefficient_and_range[0].second.max << ")>";
<< store_block_coefficient_and_range[0].first << ", "
<< store_block_coefficient_and_range[0].second << ">";
VLOG(6) << "load_block_coefficient_and_range[0] = <"
<< load_block_coefficient_and_range[0].first << ", ("
<< load_block_coefficient_and_range[0].second.min << ", "
<< load_block_coefficient_and_range[0].second.max << ")>";
<< load_block_coefficient_and_range[0].first << ", "
<< load_block_coefficient_and_range[0].second << ">";
VLOG(6) << "store_block_coefficient_and_range[1] = <"
<< store_block_coefficient_and_range[1].first << ", ("
<< store_block_coefficient_and_range[1].second.min << ", "
<< store_block_coefficient_and_range[1].second.max << ")>";
<< store_block_coefficient_and_range[1].first << ", "
<< store_block_coefficient_and_range[1].second << ">";
VLOG(6) << "load_block_coefficient_and_range[1] = <"
<< load_block_coefficient_and_range[1].first << ", ("
<< load_block_coefficient_and_range[1].second.min << ", "
<< load_block_coefficient_and_range[1].second.max << ")>";
<< load_block_coefficient_and_range[1].first << ", "
<< load_block_coefficient_and_range[1].second << ">";
VLOG(6) << "store_block_coefficient_and_range[2] = <"
<< store_block_coefficient_and_range[2].first << ", ("
<< store_block_coefficient_and_range[2].second.min << ", "
<< store_block_coefficient_and_range[2].second.max << ")>";
<< store_block_coefficient_and_range[2].first << ", "
<< store_block_coefficient_and_range[2].second << ">";
VLOG(6) << "load_block_coefficient_and_range[2] = <"
<< load_block_coefficient_and_range[2].first << ", ("
<< load_block_coefficient_and_range[2].second.min << ", "
<< load_block_coefficient_and_range[2].second.max << ")>";
<< load_block_coefficient_and_range[2].first << ", "
<< load_block_coefficient_and_range[2].second << ">";
return !(store_block_overall_range.min <= load_block_overall_range.min &&
store_block_overall_range.max >= load_block_overall_range.max &&
store_thread_and_serial_overall_range.min <=
Expand Down

0 comments on commit b20f611

Please sign in to comment.