Skip to content

Commit

Permalink
Merge pull request #27 from Synerise/seed-train-init
Browse files Browse the repository at this point in the history
Init embedding with seed during training.
  • Loading branch information
piobab authored Mar 22, 2021
2 parents f00187a + 6ebeb61 commit 4a6df21
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 4 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ Command line options (for more info use `--help` as program argument):
-r --relation-name (name of the relation, for output filename generation)
-d --dimension (number of dimensions for output embeddings)
-n --number-of-iterations (number of iterations for the algorithm, usually 3 or 4 works well)
-s --seed (seed integer for embedding initialization)
-c --columns (column format specification)
-p --prepend-field-name (prepend field name to entity in output)
-l --log-every-n (log output every N lines)
Expand Down
4 changes: 4 additions & 0 deletions src/configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ pub struct Configuration {
/// Maximum number of iteration for training
pub max_number_of_iteration: u8,

/// Seed for embedding initialization
pub seed: Option<i64>,

/// Prepend field name to entity in the output file. It differentiates entities with the same
/// name from different columns
pub prepend_field: bool,
Expand Down Expand Up @@ -78,6 +81,7 @@ impl Configuration {
produce_entity_occurrence_count: true,
embeddings_dimension: 128,
max_number_of_iteration: 4,
seed: None,
prepend_field: true,
log_every_n: 1000,
in_memory_embedding_calculation: true,
Expand Down
15 changes: 11 additions & 4 deletions src/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ trait MatrixWrapper {
fn init_with_hashes<T: SparseMatrixReader + Sync + Send>(
rows: usize,
cols: usize,
fixed_random_value: i64,
sparse_matrix_reader: Arc<T>,
) -> Self;

Expand Down Expand Up @@ -55,14 +56,15 @@ impl MatrixWrapper for TwoDimVectorMatrix {
fn init_with_hashes<T: SparseMatrixReader + Sync + Send>(
rows: usize,
cols: usize,
fixed_random_value: i64,
sparse_matrix_reader: Arc<T>,
) -> Self {
let result: Vec<Vec<f32>> = (0..cols)
.into_par_iter()
.map(|i| {
let mut col: Vec<f32> = Vec::with_capacity(rows);
for hsh in sparse_matrix_reader.iter_hashes() {
let col_value = init_value(i, hsh.value);
let col_value = init_value(i, hsh.value, fixed_random_value);
col.push(col_value);
}
col
Expand Down Expand Up @@ -129,8 +131,8 @@ impl MatrixWrapper for TwoDimVectorMatrix {
}
}

fn init_value(col: usize, hsh: u64) -> f32 {
((hash((hsh as i64) + (col as i64)) % MAX_HASH_I64) as f32) / MAX_HASH_F32
fn init_value(col: usize, hsh: u64, fixed_random_value: i64) -> f32 {
((hash((hsh as i64) + (col as i64) + fixed_random_value) % MAX_HASH_I64) as f32) / MAX_HASH_F32
}

fn hash(num: i64) -> i64 {
Expand Down Expand Up @@ -160,6 +162,7 @@ impl MatrixWrapper for MMapMatrix {
fn init_with_hashes<T: SparseMatrixReader + Sync + Send>(
rows: usize,
cols: usize,
fixed_random_value: i64,
sparse_matrix_reader: Arc<T>,
) -> Self {
let uuid = Uuid::new_v4();
Expand All @@ -172,7 +175,7 @@ impl MatrixWrapper for MMapMatrix {
// i - number of dimension
// chunk - column/vector of bytes
for (j, hsh) in sparse_matrix_reader.iter_hashes().enumerate() {
let col_value = init_value(i, hsh.value);
let col_value = init_value(i, hsh.value, fixed_random_value);
MMapMatrix::update_column(j, chunk, |value| unsafe { *value = col_value });
}
});
Expand Down Expand Up @@ -338,6 +341,7 @@ pub fn calculate_embeddings<T1, T2>(
struct MatrixMultiplicator<T: SparseMatrixReader + Sync + Send, M: MatrixWrapper> {
dimension: usize,
number_of_entities: usize,
fixed_random_value: i64,
sparse_matrix_reader: Arc<T>,
_marker: PhantomData<M>,
}
Expand All @@ -348,9 +352,11 @@ where
M: MatrixWrapper,
{
fn new(config: Arc<Configuration>, sparse_matrix_reader: Arc<T>) -> Self {
let rand_value = config.seed.map(hash).unwrap_or(0);
Self {
dimension: config.embeddings_dimension as usize,
number_of_entities: sparse_matrix_reader.get_number_of_entities() as usize,
fixed_random_value: rand_value,
sparse_matrix_reader,
_marker: PhantomData,
}
Expand All @@ -366,6 +372,7 @@ where
let result = M::init_with_hashes(
self.number_of_entities,
self.dimension,
self.fixed_random_value,
self.sparse_matrix_reader.clone(),
);

Expand Down
9 changes: 9 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ fn main() {
.help("Max number of iterations")
.takes_value(true),
)
.arg(
Arg::with_name("seed")
.short("s")
.long("seed")
.help("Seed (integer) for embedding initialization")
.takes_value(true),
)
.arg(
Arg::with_name("columns")
.short("c")
Expand Down Expand Up @@ -140,6 +147,7 @@ fn main() {
.unwrap()
.parse()
.unwrap();
let seed: Option<i64> = matches.value_of("seed").map(|s| s.parse().unwrap());
let relation_name = matches.value_of("relation-name").unwrap();
let prepend_field_name = {
let value: u8 = matches
Expand Down Expand Up @@ -180,6 +188,7 @@ fn main() {
produce_entity_occurrence_count: true,
embeddings_dimension: dimension,
max_number_of_iteration: max_iter,
seed,
prepend_field: prepend_field_name,
log_every_n: log_every,
in_memory_embedding_calculation,
Expand Down
1 change: 1 addition & 0 deletions tests/snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ fn prepare_config() -> Configuration {
produce_entity_occurrence_count: true,
embeddings_dimension: 128,
max_number_of_iteration: 4,
seed: None,
prepend_field: false,
log_every_n: 10000,
in_memory_embedding_calculation: true,
Expand Down

0 comments on commit 4a6df21

Please sign in to comment.