[core] Convert limit NF4 conversion to FP16 -> NF4 #23806

Merged
merged 8 commits, Apr 17, 2024
Changes from 2 commits
20 changes: 10 additions & 10 deletions src/common/transformations/include/ov_ops/gather_compressed.hpp
@@ -18,22 +18,22 @@ class TRANSFORMATIONS_API GatherCompressed : public ov::op::v8::Gather {

GatherCompressed() = default;

GatherCompressed(const ov::Output<Node> &data,
const ov::Output<Node> &indices,
const ov::Output<Node> &axis,
GatherCompressed(const ov::Output<Node>& data,
const ov::Output<Node>& indices,
const ov::Output<Node>& axis,
const int64_t batch_dims,
const ov::Output<Node> &decompression_scale,
const ov::Output<Node> &decompression_zero_point,
const ov::Output<Node>& decompression_scale,
const ov::Output<Node>& decompression_zero_point,
const ov::element::Type output_type = ov::element::undefined);

GatherCompressed(const ov::Output<Node> &data,
const ov::Output<Node> &indices,
const ov::Output<Node> &axis,
GatherCompressed(const ov::Output<Node>& data,
const ov::Output<Node>& indices,
const ov::Output<Node>& axis,
const int64_t batch_dims,
const ov::Output<Node> &decompression_scale,
const ov::Output<Node>& decompression_scale,
const ov::element::Type output_type = ov::element::undefined);

bool visit_attributes(ov::AttributeVisitor &visitor) override;
bool visit_attributes(ov::AttributeVisitor& visitor) override;

void validate_and_infer_types() override;

62 changes: 52 additions & 10 deletions src/core/src/op/convert.cpp
@@ -10,6 +10,7 @@
#include "openvino/op/equal.hpp"
#include "openvino/op/select.hpp"
#include "openvino/reference/convert.hpp"
#include "openvino/reference/utils/type_util.hpp"

namespace ov {
namespace op {
@@ -18,10 +19,18 @@ namespace convert {
#define CONVERT_ET_LIST \
boolean, bf16, f16, f32, f64, i4, i8, i16, i32, i64, u1, u2, u3, u4, u6, u8, u16, u32, u64, nf4, f8e4m3, f8e5m2

#define CONVERT_FROM_NF4_ET_LIST bf16, f16, f32, f64, nf4, f8e4m3, f8e5m2

#define CONVERT_FROM_NOT_FP_ET_LIST \
boolean, bf16, f16, f32, f64, i4, i8, i16, i32, i64, u1, u2, u3, u4, u6, u8, u16, u32, u64, f8e4m3, f8e5m2

struct Evaluate : public element::NoAction<bool> {
using element::NoAction<bool>::visit;

template <element::Type_t ET_IN, class TI = fundamental_type_for<ET_IN>>
// convert from FP (except NF4) to any other.
template <element::Type_t ET_IN,
Contributor: I've put a comment in the relevant ticket. In a nutshell, it would be better to support FP16 -> NF4.

Contributor Author: Changes applied to support FP16 -> NF4. Some tests had to be removed because they exercised conversions that Convert no longer supports.

Contributor: But can we update the tests instead of removing them, to make sure this works at the Python API level? That level is important for external users such as NNCF.

Contributor Author: Restored the tests for the Python API and CPU. It looks like conversion from NF4 to F16 is required to make decompression of constants possible (other types are reached by adding an extra Convert). But since the CPU plugin always converts F16 to F32 on non-ARM devices, I think it would be better for Convert to decompress NF4 to F32 directly, without the additional conversion.

Contributor: This is just a temporary situation. CPU will use BF16, INT8, or FP32 to unpack this type, and that should be done at the micro-kernel level (not in the reference implementation). The reference implementation can use any type. My only concern is providing an efficient way to compress from FP16/FP32 to NF4; ideally, we should support both.

Contributor Author: @KodiaqQ mentioned in the ticket that these changes introduce no regression. I think the PR can be merged. @AlexKoff88, can you confirm?
class TI = fundamental_type_for<ET_IN>,
typename std::enable_if<ov::is_floating_point<TI>()>::type* = nullptr>
static result_type visit(const Tensor& arg, Tensor& out, const size_t count) {
using namespace ov::element;
return IF_TYPE_OF(Convert_out,
@@ -33,6 +42,36 @@ struct Evaluate : public element::NoAction<bool> {
count);
}

// convert from integral to any except NF4
template <element::Type_t ET_IN,
class TI = fundamental_type_for<ET_IN>,
typename std::enable_if<!ov::is_floating_point<TI>() && ET_IN != element::nf4>::type* = nullptr>
static result_type visit(const Tensor& arg, Tensor& out, const size_t count) {
using namespace ov::element;
return IF_TYPE_OF(Convert_out,
CONVERT_FROM_NOT_FP_ET_LIST,
EvalByOutputType,
out.get_element_type(),
iterator<ET_IN>(reinterpret_cast<const TI*>(arg.data())),
out,
count);
}

    // convert from NF4 to FP
template <element::Type_t ET_IN,
class TI = fundamental_type_for<ET_IN>,
typename std::enable_if<ET_IN == element::nf4>::type* = nullptr>
static result_type visit(const Tensor& arg, Tensor& out, const size_t count) {
using namespace ov::element;
return IF_TYPE_OF(Convert_out,
CONVERT_FROM_NF4_ET_LIST,
EvalByOutputType,
out.get_element_type(),
iterator<ET_IN>(reinterpret_cast<const TI*>(arg.data())),
out,
count);
}

private:
struct EvalByOutputType : public element::NoAction<bool> {
using element::NoAction<bool>::visit;
@@ -118,16 +157,12 @@ bool Convert::evaluate(TensorVector& outputs, const TensorVector& inputs) const
if (auto& out = outputs[0]) {
const auto& in = inputs[0];
const auto& in_shape = in.get_shape();
const auto count = shape_size(in_shape);

out.set_shape(in_shape);

using namespace ov::element;
return IF_TYPE_OF(v0_Convert_in_et,
CONVERT_ET_LIST,
convert::Evaluate,
in.get_element_type(),
in,
out,
shape_size(in_shape));
return IF_TYPE_OF(v0_Convert_in_et, CONVERT_ET_LIST, convert::Evaluate, in.get_element_type(), in, out, count);
} else {
return false;
}
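The three `Evaluate::visit` overloads above are selected at compile time with `std::enable_if`: FP sources (except NF4) may target the full output list, integral sources target everything except NF4, and NF4 sources may only target floating-point outputs. A minimal standalone sketch of this dispatch pattern (hypothetical `Et` tag type, not the OpenVINO element types):

```cpp
#include <string>
#include <type_traits>

// Hypothetical element-type tag standing in for ov::element::Type_t.
enum class Et { f16, f32, i8, nf4 };

template <Et ET>
constexpr bool is_fp() { return ET == Et::f16 || ET == Et::f32; }

// FP source (except NF4): full output list.
template <Et ET, typename std::enable_if<is_fp<ET>(), int>::type = 0>
std::string visit() { return "fp->any"; }

// Integral source: full output list except NF4.
template <Et ET, typename std::enable_if<!is_fp<ET>() && ET != Et::nf4, int>::type = 0>
std::string visit() { return "int->not-nf4"; }

// NF4 source: floating-point outputs only.
template <Et ET, typename std::enable_if<ET == Et::nf4, int>::type = 0>
std::string visit() { return "nf4->fp"; }
```

Because the three `enable_if` conditions are mutually exclusive, exactly one overload survives substitution for any `ET`, which is why the real code can keep a single `visit` name per restriction.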
@@ -136,6 +171,10 @@ bool Convert::evaluate(TensorVector& outputs, const TensorVector& inputs) const
bool Convert::has_evaluate() const {
OV_OP_SCOPE(v0_Convert_has_evaluate);

const auto can_nf4_quantize = [](const element::Type& et) {
return et.is_real() || et == element::nf4;
};

const auto is_valid_type = [](const element::Type& et) -> bool {
switch (et) {
case element::boolean:
@@ -154,7 +193,6 @@ bool Convert::has_evaluate() const {
case element::u16:
case element::u32:
case element::u64:
case element::nf4:
case element::f8e4m3:
case element::f8e5m2:
return true;
@@ -163,7 +201,11 @@
};
};

return is_valid_type(get_input_element_type(0)) && is_valid_type(get_output_element_type(0));
const auto& input_et = get_input_element_type(0);
const auto& output_et = get_output_element_type(0);

return (is_valid_type(input_et) && is_valid_type(output_et)) ||
(can_nf4_quantize(input_et) && can_nf4_quantize(output_et));
}
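The updated `has_evaluate` rule can be modeled standalone: a conversion is evaluable if both ends are standard evaluable types (NF4 having been removed from that list), or if both ends are NF4-compatible, i.e. a real type or NF4 itself. A minimal sketch of that predicate logic (hypothetical reduced `Et` enum, not the OpenVINO types):

```cpp
// Reduced stand-in for ov::element::Type; only the cases needed here.
enum class Et { boolean, f16, f32, i8, u4, nf4 };

// NF4 was dropped from the "standard" list by this PR.
bool is_valid_type(Et et) {
    return et != Et::nf4;
}

// NF4-compatible: real (floating-point) types or NF4 itself.
bool can_nf4_quantize(Et et) {
    return et == Et::f16 || et == Et::f32 || et == Et::nf4;
}

// Evaluable if both sides are standard, or both sides are NF4-compatible.
bool has_evaluate(Et in, Et out) {
    return (is_valid_type(in) && is_valid_type(out)) ||
           (can_nf4_quantize(in) && can_nf4_quantize(out));
}
```

So `f16 -> nf4` and `nf4 -> f32` remain evaluable, while integral round-trips through NF4 (e.g. `i8 -> nf4`) are rejected, matching the PR's intent of limiting NF4 conversion to the FP compression/decompression paths.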

bool Convert::evaluate_lower(TensorVector& output_values) const {
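For background on what the FP16/FP32 -> NF4 path computes: NF4 compression maps each float to the nearest of 16 fixed levels and packs two 4-bit indices per byte. The sketch below is illustrative only; the level constants follow the published NF4 codebook (reproduced from the QLoRA/bitsandbytes sources, so treat the exact values as an assumption), and the nibble layout is arbitrary rather than OpenVINO's actual packing.

```cpp
#include <array>
#include <cmath>
#include <cstdint>
#include <vector>

// NF4 codebook: 16 levels, symmetric-ish around 0, endpoints at +/-1.
// Values assumed from the QLoRA/bitsandbytes publication.
static const std::array<float, 16> kNf4Levels = {
    -1.0f, -0.6961928f, -0.52507305f, -0.39491749f,
    -0.28444138f, -0.18477343f, -0.09105004f, 0.0f,
    0.0795803f, 0.1609302f, 0.2461123f, 0.33791524f,
    0.44070983f, 0.562617f, 0.72295684f, 1.0f};

// Nearest-level search: returns the 4-bit index for a (pre-scaled) float.
uint8_t nf4_index(float x) {
    uint8_t best = 0;
    float best_d = std::fabs(x - kNf4Levels[0]);
    for (uint8_t i = 1; i < 16; ++i) {
        const float d = std::fabs(x - kNf4Levels[i]);
        if (d < best_d) { best_d = d; best = i; }
    }
    return best;
}

// Pack two 4-bit indices per byte, low nibble first (layout is a
// hypothetical choice for this sketch).
std::vector<uint8_t> pack_nf4(const std::vector<float>& v) {
    std::vector<uint8_t> out((v.size() + 1) / 2, 0);
    for (size_t i = 0; i < v.size(); ++i) {
        const uint8_t idx = nf4_index(v[i]);
        out[i / 2] |= (i % 2 == 0) ? idx : static_cast<uint8_t>(idx << 4);
    }
    return out;
}
```

Decompression (the NF4 -> FP direction kept by this PR) is just the inverse table lookup, which is why the reference implementation can emit any floating-point output type.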