From 3b9e3fed10c952d53c14a5c6be41241b31ae606e Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Thu, 6 Apr 2023 20:12:55 +0000 Subject: [PATCH] more code cleanup --- llama-rs/src/lib.rs | 36 ++++++++++++++++++++++++------------ llama-rs/src/loader.rs | 7 ------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 5e44979c..2c57c8db 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -60,6 +60,15 @@ struct Layer { w3: ggml::Tensor, } + +/// Model Version +#[derive(Debug, PartialEq, Clone, Copy)] +pub(crate) enum ModelVersion { + GGMF, + GGJT, + Unversioned, +} + /// The weights for the LLaMA model. All the mutable state is split into a /// separate struct `InferenceSession`. pub struct Model { @@ -75,6 +84,8 @@ pub struct Model { tensors: HashMap, mmap: Option, + + version: ModelVersion, // Must be kept alive for the model _context: ggml::Context, @@ -604,10 +615,10 @@ impl Model { let mut reader = BufReader::new(&file); // Verify magic - let model_type: ModelType = match read_u32(&mut reader)? { - ggml::FILE_MAGIC_GGMF => ModelType::GGMF, - ggml::FILE_MAGIC_GGJT => ModelType::GGJT, - ggml::FILE_MAGIC_UNVERSIONED => ModelType::Unversioned, + let model_type: ModelVersion = match read_u32(&mut reader)? { + ggml::FILE_MAGIC_GGMF => ModelVersion::GGMF, + ggml::FILE_MAGIC_GGJT => ModelVersion::GGJT, + ggml::FILE_MAGIC_UNVERSIONED => ModelVersion::Unversioned, _ => { return Err(LoadError::InvalidMagic { path: main_path.to_owned(), @@ -617,13 +628,13 @@ impl Model { // Load format version match model_type { - ModelType::GGMF | ModelType::GGJT => { + ModelVersion::GGMF | ModelVersion::GGJT => { let _version: u32 = match read_u32(&mut reader)? { ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, version => return Err(LoadError::InvalidFormatVersion { value: version }), }; } - ModelType::Unversioned => {} + ModelVersion::Unversioned => {} } // ================= @@ -660,8 +671,8 @@ impl Model { for i in 0..hparams.n_vocab { let len = match model_type { // `read_i32` maybe a typo - ModelType::GGMF | ModelType::Unversioned => read_i32(&mut reader)? as usize, - ModelType::GGJT => read_u32(&mut reader)? as usize, + ModelVersion::GGMF | ModelVersion::Unversioned => read_i32(&mut reader)? as usize, + ModelVersion::GGJT => read_u32(&mut reader)? as usize, }; let maybe_word = if len > 0 { read_string(&mut reader, len) @@ -682,12 +693,12 @@ impl Model { // Token score, currently unused match model_type { - ModelType::GGMF | ModelType::GGJT => { + ModelVersion::GGMF | ModelVersion::GGJT => { if let Ok(score) = read_f32(&mut reader) { id_to_token_score.push(score); } } - ModelType::Unversioned => { + ModelVersion::Unversioned => { // Legacy model, set empty score id_to_token_score.push(0.); } @@ -815,11 +826,12 @@ impl Model { tensors, _context: context, mmap: None, + version: model_type, } }; match model_type { - ModelType::GGMF | ModelType::Unversioned => { + ModelVersion::GGMF | ModelVersion::Unversioned => { let file_offset = reader.stream_position()?; drop(reader); load_weights_ggmf_or_unversioned( @@ -829,7 +841,7 @@ impl Model { &model, )? } - ModelType::GGJT => { + ModelVersion::GGJT => { let mmap = unsafe { Mmap::map(&file)? }; load_weights_ggjt( &mut reader, diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index e7326e13..602bd080 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -43,13 +43,6 @@ fn has_data_left(reader: &mut impl BufRead) -> Result { reader.fill_buf().map(|b| !b.is_empty()) } -#[derive(PartialEq)] -pub(crate) enum ModelType { - GGMF, - GGJT, - Unversioned, -} - pub(crate) fn load_weights_ggmf_or_unversioned( file_offset: u64, main_path: &Path,