From 8d674b25bd3d6f4d05293215aae78d2184d3e14e Mon Sep 17 00:00:00 2001 From: Riceball LEE Date: Fri, 26 Jan 2024 16:32:12 +0800 Subject: [PATCH] feat: add new GGUFValueType.OBJ virtual type The content of the OBJ type is actually a list of all key names of the object. * GGUFWriter: * add `def add_kv(self, key: str, val: Any) -> None`: This will be added based on the val type * add `def add_dict(self, key: str, val: dict) -> None`: add object(dict) value * constants: * `GGUFValueType.get_type`: Added support for Numpy's integers and floating-point numbers, and selected the appropriate number of digits based on the size of the integer. * gguf_reader: * add `ReaderField.get`: to return the value of the field * Unit test added. Related Issues: #4868, #2872 --- gguf-py/tests/test_constants.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 gguf-py/tests/test_constants.py diff --git a/gguf-py/tests/test_constants.py b/gguf-py/tests/test_constants.py new file mode 100644 index 00000000000000..a1ba75f925825c --- /dev/null +++ b/gguf-py/tests/test_constants.py @@ -0,0 +1,32 @@ +import sys +from pathlib import Path +import numpy as np +import unittest + +# Necessary to load the local gguf package +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from gguf.constants import GGUFValueType + +class TestGGUFValueType(unittest.TestCase): + + def test_get_type(self): + self.assertEqual(GGUFValueType.get_type("test"), GGUFValueType.STRING) + self.assertEqual(GGUFValueType.get_type([1, 2, 3]), GGUFValueType.ARRAY) + self.assertEqual(GGUFValueType.get_type(1.0), GGUFValueType.FLOAT32) + self.assertEqual(GGUFValueType.get_type(True), GGUFValueType.BOOL) + self.assertEqual(GGUFValueType.get_type(b"test"), GGUFValueType.STRING) + self.assertEqual(GGUFValueType.get_type(np.uint8(1)), GGUFValueType.UINT8) + self.assertEqual(GGUFValueType.get_type(np.uint16(1)), GGUFValueType.UINT16) + self.assertEqual(GGUFValueType.get_type(np.uint32(1)), GGUFValueType.UINT32) + self.assertEqual(GGUFValueType.get_type(np.uint64(1)), GGUFValueType.UINT64) + self.assertEqual(GGUFValueType.get_type(np.int8(-1)), GGUFValueType.INT8) + self.assertEqual(GGUFValueType.get_type(np.int16(-1)), GGUFValueType.INT16) + self.assertEqual(GGUFValueType.get_type(np.int32(-1)), GGUFValueType.INT32) + self.assertEqual(GGUFValueType.get_type(np.int64(-1)), GGUFValueType.INT64) + self.assertEqual(GGUFValueType.get_type(np.float32(1.0)), GGUFValueType.FLOAT32) + self.assertEqual(GGUFValueType.get_type(np.float64(1.0)), GGUFValueType.FLOAT64) + self.assertEqual(GGUFValueType.get_type({"k": 12}), GGUFValueType.OBJ) + +if __name__ == '__main__': + unittest.main()