Add utility methods to split gradients to GradientParams (#2311)
* Add utility methods to split gradients to GradientParams

* Forget about from_segments for now

* ParamId -> u64

* Fix test

* Don't need two lifetimes

* Backwards compatibility deserialization

* Always serialize same format

* Better compat with old formats

* Fix no_std

* Add a bit to the book

* Move deserialize function, add test

* Tweak test

* Add backward compat for 16-byte uuid

---------

Co-authored-by: Guillaume Lagrange <[email protected]>
ArthurBrussee and laggui authored Oct 4, 2024
1 parent 8327deb commit dbd577a
Showing 24 changed files with 248 additions and 139 deletions.
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions burn-book/src/building-blocks/module.md
@@ -90,13 +90,13 @@ You can implement your own mapper or visitor by implementing these simple traits
/// Module visitor trait.
pub trait ModuleVisitor<B: Backend> {
/// Visit a tensor in the module.
fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>);
fn visit<const D: usize>(&mut self, id: ParamId, tensor: &Tensor<B, D>);
}
/// Module mapper trait.
pub trait ModuleMapper<B: Backend> {
/// Map a tensor in the module.
fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) ->
fn map<const D: usize>(&mut self, id: ParamId, tensor: Tensor<B, D>) ->
Tensor<B, D>;
}
```
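
For context on this signature change, a minimal visitor written against the updated API might look like the sketch below. Note that the concrete trait in `burn-core` (see the `base.rs` changes further down) uses `visit_float`/`visit_int`/`visit_bool`; the `ParamCounter` type is purely illustrative and not part of this commit.

```rust
use burn::module::{ModuleVisitor, ParamId};
use burn::tensor::{backend::Backend, Tensor};

/// Sums the number of elements across every float parameter it visits.
#[derive(Default)]
struct ParamCounter {
    count: usize,
}

impl<B: Backend> ModuleVisitor<B> for ParamCounter {
    fn visit_float<const D: usize>(&mut self, _id: ParamId, tensor: &Tensor<B, D>) {
        // `ParamId` is now a small `Copy` value, so it is passed by value
        // instead of by reference.
        self.count += tensor.shape().num_elements();
    }
}

// Usage: `model.visit(&mut counter)` for any `model` implementing `Module<B>`.
```
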
31 changes: 31 additions & 0 deletions burn-book/src/custom-training-loop.md
@@ -149,6 +149,37 @@ You can find the code above available as an
[example](https://github.com/tracel-ai/burn/tree/main/examples/custom-training-loop) for you to
test.

## Multiple optimizers

It's common practice to use different learning rates, different optimizer parameters, or even entirely different optimizers for different parts
of a model. In Burn, each `GradientsParams` can contain just a subset of the gradients to apply with an optimizer.
This allows you to flexibly mix and match optimizers!

```rust,ignore
// Start by calculating all the gradients
let mut grads = loss.backward();
// Now split the gradients into various parts.
let grads_conv1 = GradientsParams::from_module(&mut grads, &model.conv1);
let grads_conv2 = GradientsParams::from_module(&mut grads, &model.conv2);
// You can step the model with these gradients, using a different learning
// rate for each group. You could also use an entirely different optimizer here!
model = optim.step(config.lr * 2.0, model, grads_conv1);
model = optim.step(config.lr * 4.0, model, grads_conv2);
// For even more granular control, you can split off individual parameters,
// e.g. a linear bias usually needs a smaller learning rate.
if let Some(bias) = model.linear1.bias.as_ref() {
let grads_bias = GradientsParams::from_params(&mut grads, &model.linear1, &[bias.id]);
model = optim.step(config.lr * 0.1, model, grads_bias);
}
// Note that the calls above remove the gradients they split off, so we can just grab all "remaining" gradients.
let grads = GradientsParams::from_grads(grads, &model);
model = optim.step(config.lr, model, grads);
```
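
As a sketch of the "entirely different optimizer" case, you could build one optimizer per part of the model. The optimizer choices below (Adam for the convolution, SGD for the linear layer) are purely illustrative and assume a model with `conv1` and `linear1` fields, a fresh `grads` from `loss.backward()`, and the learning rate from `config` as above.

```rust,ignore
// Illustrative only: one optimizer per part of the model.
let mut optim_conv = AdamConfig::new().init();
let mut optim_linear = SgdConfig::new().init();

let grads_conv = GradientsParams::from_module(&mut grads, &model.conv1);
let grads_linear = GradientsParams::from_module(&mut grads, &model.linear1);

// Each optimizer only updates the parameters whose gradients it was given.
model = optim_conv.step(config.lr, model, grads_conv);
model = optim_linear.step(config.lr * 0.5, model, grads_linear);
```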

## Custom Type

The explanations above demonstrate how to create a basic training loop. However, you may find it
2 changes: 1 addition & 1 deletion burn-book/src/quantization.md
@@ -73,7 +73,7 @@ let model = model.quantize_weights(&mut quantizer);
> impl<B: Backend> ModuleMapper<B> for Dequantize {
> fn map_float<const D: usize>(
> &mut self,
> _id: &ParamId,
> _id: ParamId,
> tensor: Tensor<B, D>,
> ) -> Tensor<B, D> {
> tensor.dequantize()
2 changes: 0 additions & 2 deletions crates/burn-common/Cargo.toml
@@ -23,8 +23,6 @@ web-time = { version = "1.1.0" }


[dependencies]
data-encoding = { workspace = true }

# Network downloader
indicatif = { workspace = true, optional = true }
reqwest = { workspace = true, optional = true }
26 changes: 7 additions & 19 deletions crates/burn-common/src/id.rs
@@ -1,45 +1,33 @@
use alloc::string::String;

use crate::rand::gen_random;

use data_encoding::BASE32_DNSSEC;

/// Simple ID generator.
pub struct IdGenerator {}

impl IdGenerator {
/// Generates a new ID.
pub fn generate() -> String {
// Generate 6 random bytes (281,474,976,710,656 combinations)
let random_bytes: [u8; 6] = gen_random();

// Encode the random bytes in base32 DNSSEC
// 6 bytes encodes to 10 lower case characters, e.g. "3uu5e6vv7c"
BASE32_DNSSEC.encode(&random_bytes)
pub fn generate() -> u64 {
// Generate a random u64 (18,446,744,073,709,551,615 combinations)
let random_bytes: [u8; 8] = gen_random();
u64::from_le_bytes(random_bytes)
}
}

#[cfg(test)]
mod tests {
use super::*;

use alloc::{collections::BTreeSet, string::String};
use alloc::collections::BTreeSet;

#[cfg(feature = "std")]
use dashmap::DashSet; //Concurrent HashMap
#[cfg(feature = "std")]
use std::{sync::Arc, thread};

#[test]
fn not_empty_test() {
assert!(!IdGenerator::generate().is_empty());
}

#[test]
fn uniqueness_test() {
const IDS_CNT: usize = 10_000;

let mut set: BTreeSet<String> = BTreeSet::new();
let mut set: BTreeSet<u64> = BTreeSet::new();

for _i in 0..IDS_CNT {
assert!(set.insert(IdGenerator::generate()));
@@ -55,7 +43,7 @@ mod tests {
const NUM_REPEATS: usize = 1_000;
const EXPECTED_TOTAL_IDS: usize = NUM_THREADS * NUM_REPEATS;

let set: Arc<DashSet<String>> = Arc::new(DashSet::new());
let set: Arc<DashSet<u64>> = Arc::new(DashSet::new());

let mut handles = vec![];

5 changes: 4 additions & 1 deletion crates/burn-core/Cargo.toml
@@ -118,6 +118,9 @@ burn-autodiff = { path = "../burn-autodiff", version = "0.15.0", optional = true
burn-tch = { path = "../burn-tch", version = "0.15.0", optional = true }
burn-candle = { path = "../burn-candle", version = "0.15.0", optional = true }

data-encoding = { workspace = true }
uuid = { workspace = true }

derive-new = { workspace = true }
log = { workspace = true, optional = true }
rand = { workspace = true, features = ["std_rng"] } # Default enables std
@@ -136,7 +139,7 @@ serde_json = { workspace = true, features = ["alloc"] } #Default enables std
thiserror = { workspace = true, optional = true }
regex = { workspace = true, optional = true }
num-traits = { workspace = true }
spin = { workspace = true } # Using in place of use std::sync::Mutex when std is disabled
spin = { workspace = true } # Using in place of use std::sync::Mutex when std is disabled

[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
portable-atomic-util = { workspace = true }
16 changes: 8 additions & 8 deletions crates/burn-core/src/module/base.rs
@@ -19,7 +19,7 @@ macro_rules! module {
impl<B: Backend> ModuleMapper<B> for Mapper {
fn map_float<const D: usize>(
&mut self,
_id: &ParamId,
_id: ParamId,
tensor: Tensor<B, D>,
) -> Tensor<B, D> {
let func = $item;
@@ -35,7 +35,7 @@ macro_rules! module {
backend: core::marker::PhantomData<B>,
}
impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
fn visit_float<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
fn visit_float<const D: usize>(&mut self, _id: ParamId, tensor: &Tensor<B, D>) {
let func = $item;
func(tensor, &mut self.state)
}
@@ -212,31 +212,31 @@ pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
/// Module visitor trait.
pub trait ModuleVisitor<B: Backend> {
/// Visit a float tensor in the module.
fn visit_float<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D>) {}
fn visit_float<const D: usize>(&mut self, _id: ParamId, _tensor: &Tensor<B, D>) {}
/// Visit an int tensor in the module.
fn visit_int<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D, Int>) {}
fn visit_int<const D: usize>(&mut self, _id: ParamId, _tensor: &Tensor<B, D, Int>) {}
/// Visit a bool tensor in the module.
fn visit_bool<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D, Bool>) {}
fn visit_bool<const D: usize>(&mut self, _id: ParamId, _tensor: &Tensor<B, D, Bool>) {}
}

/// Module mapper trait.
pub trait ModuleMapper<B: Backend> {
/// Map a float tensor in the module.
fn map_float<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
fn map_float<const D: usize>(&mut self, _id: ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
tensor
}
/// Map an int tensor in the module.
fn map_int<const D: usize>(
&mut self,
_id: &ParamId,
_id: ParamId,
tensor: Tensor<B, D, Int>,
) -> Tensor<B, D, Int> {
tensor
}
/// Map a bool tensor in the module.
fn map_bool<const D: usize>(
&mut self,
_id: &ParamId,
_id: ParamId,
tensor: Tensor<B, D, Bool>,
) -> Tensor<B, D, Bool> {
tensor
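
Similarly, a mapper against the new by-value `map_float` signature is sketched below; `ScaleWeights` is a made-up example for illustration, not part of this diff.

```rust
use burn::module::{ModuleMapper, ParamId};
use burn::tensor::{backend::Backend, Tensor};

/// Multiplies every float parameter by a constant factor. Int and bool
/// tensors keep their default identity mapping.
struct ScaleWeights {
    factor: f32,
}

impl<B: Backend> ModuleMapper<B> for ScaleWeights {
    fn map_float<const D: usize>(&mut self, _id: ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
        // The id is `Copy` and taken by value after this change.
        tensor.mul_scalar(self.factor)
    }
}

// Usage: `let model = model.map(&mut ScaleWeights { factor: 0.5 });`
```
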
2 changes: 1 addition & 1 deletion crates/burn-core/src/module/param/base.rs
@@ -235,7 +235,7 @@ impl<T: Parameter> Param<T> {

impl<T: Parameter> Clone for Param<T> {
fn clone(&self) -> Self {
Param::initialized(self.id.clone(), self.val())
Param::initialized(self.id, self.val())
}
}

96 changes: 80 additions & 16 deletions crates/burn-core/src/module/param/id.rs
@@ -1,22 +1,18 @@
use alloc::string::{String, ToString};
use core::hash::{BuildHasher, Hasher};

use alloc::string::String;
use burn_common::id::IdGenerator;
use data_encoding::BASE32_DNSSEC;
use hashbrown::hash_map::DefaultHashBuilder;

/// Parameter ID.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub struct ParamId {
value: String,
}

impl From<&str> for ParamId {
fn from(val: &str) -> Self {
Self {
value: val.to_string(),
}
}
value: u64,
}

impl From<String> for ParamId {
fn from(value: String) -> Self {
impl From<u64> for ParamId {
fn from(value: u64) -> Self {
Self { value }
}
}
@@ -35,14 +31,82 @@ impl ParamId {
}
}

/// Convert the parameter ID into a string.
pub fn into_string(self) -> String {
/// Gets the internal value of the id.
pub fn val(&self) -> u64 {
self.value
}

/// Convert the parameter ID into a string.
pub fn serialize(self) -> String {
BASE32_DNSSEC.encode(&self.value.to_le_bytes())
}

/// Deserialize a param id.
///
/// Preserves compatibility with previous formats (6 bytes, 16-byte uuid).
pub fn deserialize(encoded: &str) -> ParamId {
let u64_id = match BASE32_DNSSEC.decode(encoded.as_bytes()) {
Ok(bytes) => {
let mut buffer = [0u8; 8];
buffer[..bytes.len()].copy_from_slice(&bytes);
u64::from_le_bytes(buffer)
}
Err(err) => match uuid::Uuid::try_parse(encoded) {
// Backward compatibility with uuid parameter identifiers
Ok(id) => {
// Hash the 128-bit uuid to 64-bit
// Though not *theoretically* unique, the probability of a collision should be extremely low
let mut hasher = DefaultHashBuilder::default().build_hasher();
// let mut hasher = DefaultHasher::new();
hasher.write(id.as_bytes());
hasher.finish()
}
Err(_) => panic!("Invalid id. {err}"),
},
};

ParamId::from(u64_id)
}
}

impl core::fmt::Display for ParamId {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str(self.value.as_str())
f.write_str(&self.serialize())
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn param_serde_deserialize() {
let val = ParamId::from(123456u64);
let deserialized = ParamId::deserialize(&val.serialize());
assert_eq!(val, deserialized);
}

#[test]
fn param_serde_deserialize_legacy() {
let legacy_val = [45u8; 6];
let param_id = ParamId::deserialize(&BASE32_DNSSEC.encode(&legacy_val));
assert_eq!(param_id.val().to_le_bytes()[0..6], legacy_val);
assert_eq!(param_id.val().to_le_bytes()[6..], [0, 0]);
}

#[test]
fn param_serde_deserialize_legacy_uuid() {
// Ensure support for legacy uuid deserialization and make sure it results in the same output
let legacy_id = "30b82c23-788d-4d63-a743-ada258d5f13c";
let param_id1 = ParamId::deserialize(legacy_id);
let param_id2 = ParamId::deserialize(legacy_id);
assert_eq!(param_id1, param_id2);
}

#[test]
#[should_panic = "Invalid id."]
fn param_serde_deserialize_invalid_id() {
let invalid_uuid = "30b82c23-788d-4d63-ada258d5f13c";
let _ = ParamId::deserialize(invalid_uuid);
}
}
7 changes: 3 additions & 4 deletions crates/burn-core/src/module/param/running.rs
@@ -79,13 +79,12 @@ impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {

fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
let tensor = self.value.lock().unwrap();

visitor.visit_float(&self.id, &tensor)
visitor.visit_float(self.id, &tensor)
}

fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
let mut tensor = self.value.lock().unwrap();
let tensor_out = mapper.map_float(&self.id, tensor.clone());
let tensor_out = mapper.map_float(self.id, tensor.clone());

*tensor = tensor_out;
core::mem::drop(tensor);
@@ -246,6 +245,6 @@ impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for RunningState<Tens
self.sync();
let value = self.value();

RunningState::with_id(self.id.clone(), value.inner())
RunningState::with_id(self.id, value.inner())
}
}