Skip to content

Commit

Permalink
feat: add new GGUFValueType.OBJ virtual type
Browse files Browse the repository at this point in the history
The content of the OBJ type is actually a list of all key names of the object.

* GGUFWriter:
  * add `def add_kv(self, key: str, val: Any) -> None`:  This will be added based on the val type
  * add `def add_dict(self, key: str, val: dict) -> None`: add object(dict) value
* constants:
  * `GGUFValueType.get_type`: Added support for Numpy's integers and floating-point numbers, and selected the appropriate number of digits based on the size of the integer.
* gguf_reader:
  * add `ReaderField.get`: to return the value of the field
* Unit test added.

Related Issues: ggml-org#4868, ggml-org#2872
  • Loading branch information
snowyu committed Jan 26, 2024
1 parent c25052b commit 8d674b2
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions gguf-py/tests/test_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import sys
from pathlib import Path
import numpy as np
import unittest

# Necessary to load the local gguf package
sys.path.insert(0, str(Path(__file__).parent.parent))

from gguf.constants import GGUFValueType

class TestGGUFValueType(unittest.TestCase):

def test_get_type(self):
self.assertEqual(GGUFValueType.get_type("test"), GGUFValueType.STRING)
self.assertEqual(GGUFValueType.get_type([1, 2, 3]), GGUFValueType.ARRAY)
self.assertEqual(GGUFValueType.get_type(1.0), GGUFValueType.FLOAT32)
self.assertEqual(GGUFValueType.get_type(True), GGUFValueType.BOOL)
self.assertEqual(GGUFValueType.get_type(b"test"), GGUFValueType.STRING)
self.assertEqual(GGUFValueType.get_type(np.uint8(1)), GGUFValueType.UINT8)
self.assertEqual(GGUFValueType.get_type(np.uint16(1)), GGUFValueType.UINT16)
self.assertEqual(GGUFValueType.get_type(np.uint32(1)), GGUFValueType.UINT32)
self.assertEqual(GGUFValueType.get_type(np.uint64(1)), GGUFValueType.UINT64)
self.assertEqual(GGUFValueType.get_type(np.int8(-1)), GGUFValueType.INT8)
self.assertEqual(GGUFValueType.get_type(np.int16(-1)), GGUFValueType.INT16)
self.assertEqual(GGUFValueType.get_type(np.int32(-1)), GGUFValueType.INT32)
self.assertEqual(GGUFValueType.get_type(np.int64(-1)), GGUFValueType.INT64)
self.assertEqual(GGUFValueType.get_type(np.float32(1.0)), GGUFValueType.FLOAT32)
self.assertEqual(GGUFValueType.get_type(np.float64(1.0)), GGUFValueType.FLOAT64)
self.assertEqual(GGUFValueType.get_type({"k": 12}), GGUFValueType.OBJ)

if __name__ == '__main__':
unittest.main()

0 comments on commit 8d674b2

Please sign in to comment.