Skip to content

Commit

Permalink
Reduce NF4 conversion only from FP16
Browse files Browse the repository at this point in the history
in convert operator
  • Loading branch information
praasz committed Apr 3, 2024
1 parent 89a463e commit f8ce4eb
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 71 deletions.

This file was deleted.

27 changes: 12 additions & 15 deletions src/core/src/op/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,52 +19,50 @@ namespace convert {
#define CONVERT_ET_LIST \
boolean, bf16, f16, f32, f64, i4, i8, i16, i32, i64, u1, u2, u3, u4, u6, u8, u16, u32, u64, nf4, f8e4m3, f8e5m2

#define CONVERT_FROM_NF4_ET_LIST bf16, f16, f32, f64, nf4, f8e4m3, f8e5m2

#define CONVERT_FROM_NOT_FP_ET_LIST \
#define CONVERT_TO_ANY_NO_NF4 \
boolean, bf16, f16, f32, f64, i4, i8, i16, i32, i64, u1, u2, u3, u4, u6, u8, u16, u32, u64, f8e4m3, f8e5m2

struct Evaluate : public element::NoAction<bool> {
using element::NoAction<bool>::visit;

// convert from FP (except NF4) to any other.
// convert from any (except F16, NF4) to any except NF4
template <element::Type_t ET_IN,
class TI = fundamental_type_for<ET_IN>,
typename std::enable_if<ov::is_floating_point<TI>()>::type* = nullptr>
typename std::enable_if<ET_IN != element::f16 && ET_IN != element::nf4>::type* = nullptr>
static result_type visit(const Tensor& arg, Tensor& out, const size_t count) {
using namespace ov::element;
return IF_TYPE_OF(Convert_out,
CONVERT_ET_LIST,
CONVERT_TO_ANY_NO_NF4,
EvalByOutputType,
out.get_element_type(),
iterator<ET_IN>(reinterpret_cast<const TI*>(arg.data())),
out,
count);
}

// convert from integral to any except NF4
// convert from F16 to any
template <element::Type_t ET_IN,
class TI = fundamental_type_for<ET_IN>,
typename std::enable_if<!ov::is_floating_point<TI>() && ET_IN != element::nf4>::type* = nullptr>
typename std::enable_if<ET_IN == element::f16>::type* = nullptr>
static result_type visit(const Tensor& arg, Tensor& out, const size_t count) {
using namespace ov::element;
return IF_TYPE_OF(Convert_out,
CONVERT_FROM_NOT_FP_ET_LIST,
CONVERT_ET_LIST,
EvalByOutputType,
out.get_element_type(),
iterator<ET_IN>(reinterpret_cast<const TI*>(arg.data())),
out,
count);
}

// convert from NF4 to FP
// convert from NF4 to NF4
template <element::Type_t ET_IN,
class TI = fundamental_type_for<ET_IN>,
typename std::enable_if<ET_IN == element::nf4>::type* = nullptr>
static result_type visit(const Tensor& arg, Tensor& out, const size_t count) {
using namespace ov::element;
return IF_TYPE_OF(Convert_out,
CONVERT_FROM_NF4_ET_LIST,
OV_PP_ET_LIST(nf4),
EvalByOutputType,
out.get_element_type(),
iterator<ET_IN>(reinterpret_cast<const TI*>(arg.data())),
Expand Down Expand Up @@ -171,8 +169,8 @@ bool Convert::evaluate(TensorVector& outputs, const TensorVector& inputs) const
bool Convert::has_evaluate() const {
OV_OP_SCOPE(v0_Convert_has_evaluate);

const auto can_nf4_quantize = [](const element::Type& et) {
return et.is_real() || et == element::nf4;
const auto is_to_nf4_supported = [](const element::Type& from, const element::Type& to) {
return (from == element::f16 || from == element::nf4) && (to == element::nf4);
};

const auto is_valid_type = [](const element::Type& et) -> bool {
Expand Down Expand Up @@ -204,8 +202,7 @@ bool Convert::has_evaluate() const {
const auto& input_et = get_input_element_type(0);
const auto& output_et = get_output_element_type(0);

return (is_valid_type(input_et) && is_valid_type(output_et)) ||
(can_nf4_quantize(input_et) && can_nf4_quantize(output_et));
return (is_valid_type(input_et) && is_valid_type(output_et)) || is_to_nf4_supported(input_et, output_et);
}

bool Convert::evaluate_lower(TensorVector& output_values) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,7 @@ std::vector<ov::AnyMap> filter_additional_config_amx() {
}

const std::vector<ov::test::ElementType> decompression_precisions = {ov::element::f32};
const std::vector<ov::test::ElementType> weights_precisions = {ov::element::u8,
ov::element::u4,
ov::element::i4,
ov::element::nf4};
const std::vector<ov::test::ElementType> weights_precisions = {ov::element::u8, ov::element::u4, ov::element::i4};

const std::vector<ShapeParams> input_shapes_basic = {
{{{-1, -1, -1}, {{1, 4, 16}, {10, 16, 16}}}, {16, 32}},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1612,9 +1612,17 @@ INSTANTIATE_TEST_SUITE_P(
// destination nf4 (use quantization)
ConvertParams(ConversionTypes::CONVERT_LIKE,
ov::PartialShape{4},
ov::element::f32,
ov::element::f16,
ov::element::nf4,
std::vector<float16>{-0.6961928009986877f, 0.7229568362236023f, 1.0f, -0.5250730514526367f},
std::vector<uint8_t>{0xE1, 0x2F},
4,
4),
ConvertParams(ConversionTypes::CONVERT_LIKE,
ov::PartialShape{4},
ov::element::nf4,
ov::element::nf4,
std::vector<float>{-0.6961928009986877f, 0.7229568362236023f, 1.0f, -0.5250730514526367f},
std::vector<uint8_t>{0xE1, 0x2f},
std::vector<uint8_t>{0xE1, 0x2F},
4,
4),
Expand Down

0 comments on commit f8ce4eb

Please sign in to comment.