Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Init embedding with seed during training. #27

Merged
merged 1 commit into from
Mar 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ Command line options (for more info use `--help` as program argument):
-r --relation-name (name of the relation, for output filename generation)
-d --dimension (number of dimensions for output embeddings)
-n --number-of-iterations (number of iterations for the algorithm, usually 3 or 4 works well)
-s --seed (seed integer for embedding initialization)
-c --columns (column format specification)
-p --prepend-field-name (prepend field name to entity in output)
-l --log-every-n (log output every N lines)
Expand Down
4 changes: 4 additions & 0 deletions src/configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ pub struct Configuration {
/// Maximum number of iterations for training
pub max_number_of_iteration: u8,

/// Seed for embedding initialization
pub seed: Option<i64>,

/// Prepend field name to entity in the output file. It differentiates entities with the same
/// name from different columns
pub prepend_field: bool,
Expand Down Expand Up @@ -78,6 +81,7 @@ impl Configuration {
produce_entity_occurrence_count: true,
embeddings_dimension: 128,
max_number_of_iteration: 4,
seed: None,
prepend_field: true,
log_every_n: 1000,
in_memory_embedding_calculation: true,
Expand Down
15 changes: 11 additions & 4 deletions src/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ trait MatrixWrapper {
fn init_with_hashes<T: SparseMatrixReader + Sync + Send>(
rows: usize,
cols: usize,
fixed_random_value: i64,
sparse_matrix_reader: Arc<T>,
) -> Self;

Expand Down Expand Up @@ -55,14 +56,15 @@ impl MatrixWrapper for TwoDimVectorMatrix {
fn init_with_hashes<T: SparseMatrixReader + Sync + Send>(
rows: usize,
cols: usize,
fixed_random_value: i64,
sparse_matrix_reader: Arc<T>,
) -> Self {
let result: Vec<Vec<f32>> = (0..cols)
.into_par_iter()
.map(|i| {
let mut col: Vec<f32> = Vec::with_capacity(rows);
for hsh in sparse_matrix_reader.iter_hashes() {
let col_value = init_value(i, hsh.value);
let col_value = init_value(i, hsh.value, fixed_random_value);
col.push(col_value);
}
col
Expand Down Expand Up @@ -129,8 +131,8 @@ impl MatrixWrapper for TwoDimVectorMatrix {
}
}

fn init_value(col: usize, hsh: u64) -> f32 {
((hash((hsh as i64) + (col as i64)) % MAX_HASH_I64) as f32) / MAX_HASH_F32
/// Computes the deterministic initial value of a single embedding cell.
///
/// The entity hash, the column index and the seed-derived offset are summed, the sum is passed
/// through `hash`, reduced modulo `MAX_HASH_I64` and finally divided by `MAX_HASH_F32`.
///
/// * `col` - index of the embedding column being initialized
/// * `hsh` - hash value of the entity
/// * `fixed_random_value` - offset mixed into the hash input; callers pass `0` when no seed
///   is configured
fn init_value(col: usize, hsh: u64, fixed_random_value: i64) -> f32 {
    // `hsh as i64` may land anywhere in the i64 range, so a plain `+` can overflow and panic
    // when overflow checks are enabled (debug builds). Wrapping addition keeps the value that
    // release builds already produce while never panicking.
    let mixed = (hsh as i64)
        .wrapping_add(col as i64)
        .wrapping_add(fixed_random_value);
    ((hash(mixed) % MAX_HASH_I64) as f32) / MAX_HASH_F32
}

fn hash(num: i64) -> i64 {
Expand Down Expand Up @@ -160,6 +162,7 @@ impl MatrixWrapper for MMapMatrix {
fn init_with_hashes<T: SparseMatrixReader + Sync + Send>(
rows: usize,
cols: usize,
fixed_random_value: i64,
sparse_matrix_reader: Arc<T>,
) -> Self {
let uuid = Uuid::new_v4();
Expand All @@ -172,7 +175,7 @@ impl MatrixWrapper for MMapMatrix {
// i - number of dimension
// chunk - column/vector of bytes
for (j, hsh) in sparse_matrix_reader.iter_hashes().enumerate() {
let col_value = init_value(i, hsh.value);
let col_value = init_value(i, hsh.value, fixed_random_value);
MMapMatrix::update_column(j, chunk, |value| unsafe { *value = col_value });
}
});
Expand Down Expand Up @@ -338,6 +341,7 @@ pub fn calculate_embeddings<T1, T2>(
/// Holds the parameters used to build embedding matrices of type `M` from the data exposed
/// by a sparse matrix reader of type `T`.
struct MatrixMultiplicator<T, M>
where
    T: SparseMatrixReader + Sync + Send,
    M: MatrixWrapper,
{
    /// Number of embedding dimensions (taken from the configuration).
    dimension: usize,
    /// Number of entities reported by the sparse matrix reader.
    number_of_entities: usize,
    /// Seed-derived offset forwarded to the matrix initialization; `0` when no seed is set.
    fixed_random_value: i64,
    /// Shared handle to the sparse matrix reader.
    sparse_matrix_reader: Arc<T>,
    /// Ties the struct to the matrix implementation `M` without storing a value of it.
    _marker: PhantomData<M>,
}
Expand All @@ -348,9 +352,11 @@ where
M: MatrixWrapper,
{
fn new(config: Arc<Configuration>, sparse_matrix_reader: Arc<T>) -> Self {
let rand_value = config.seed.map(hash).unwrap_or(0);
Self {
dimension: config.embeddings_dimension as usize,
number_of_entities: sparse_matrix_reader.get_number_of_entities() as usize,
fixed_random_value: rand_value,
sparse_matrix_reader,
_marker: PhantomData,
}
Expand All @@ -366,6 +372,7 @@ where
let result = M::init_with_hashes(
self.number_of_entities,
self.dimension,
self.fixed_random_value,
self.sparse_matrix_reader.clone(),
);

Expand Down
9 changes: 9 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ fn main() {
.help("Max number of iterations")
.takes_value(true),
)
.arg(
Arg::with_name("seed")
.short("s")
.long("seed")
.help("Seed (integer) for embedding initialization")
.takes_value(true),
)
.arg(
Arg::with_name("columns")
.short("c")
Expand Down Expand Up @@ -140,6 +147,7 @@ fn main() {
.unwrap()
.parse()
.unwrap();
let seed: Option<i64> = matches.value_of("seed").map(|s| s.parse().unwrap());
let relation_name = matches.value_of("relation-name").unwrap();
let prepend_field_name = {
let value: u8 = matches
Expand Down Expand Up @@ -180,6 +188,7 @@ fn main() {
produce_entity_occurrence_count: true,
embeddings_dimension: dimension,
max_number_of_iteration: max_iter,
seed,
prepend_field: prepend_field_name,
log_every_n: log_every,
in_memory_embedding_calculation,
Expand Down
1 change: 1 addition & 0 deletions tests/snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ fn prepare_config() -> Configuration {
produce_entity_occurrence_count: true,
embeddings_dimension: 128,
max_number_of_iteration: 4,
seed: None,
prepend_field: false,
log_every_n: 10000,
in_memory_embedding_calculation: true,
Expand Down