Skip to content

Commit

Permalink
Add SAm
Browse files Browse the repository at this point in the history
  • Loading branch information
jamjamjon committed Jul 28, 2024
1 parent 541966d commit b7e2018
Show file tree
Hide file tree
Showing 13 changed files with 513 additions and 28 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ cargo run -r --example yolo # blip, clip, yolop, svtr, db, ...

## Integrate into your own project

### 1. Add `usls` as a dependency to your project's `Cargo.toml`
### Add `usls` as a dependency to your project's `Cargo.toml`

```Shell
cargo add usls
Expand Down
Binary file added assets/truck.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
51 changes: 51 additions & 0 deletions examples/sam/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use usls::{
models::{SamPrompt, SAM},
Annotator, DataLoader, Options,
};

fn main() -> Result<(), Box<dyn std::error::Error>> {
// encoder
let options_encoder = Options::default()
// .with_cpu()
.with_i00((1, 1, 1).into())
.with_model("mobile-sam-vit-t-encoder.onnx")?;

// decoder
let options_decoder = Options::default()
// .with_cpu()
.with_i11((1, 1, 1).into())
.with_i21((1, 1, 1).into())
.with_find_contours(true) // find contours or not
.with_model("mobile-sam-vit-t-decoder.onnx")?;

// build model
let mut model = SAM::new(options_encoder, options_decoder)?;

// build dataloader
let dl = DataLoader::default()
.with_batch(model.batch() as _)
.load("./assets/truck.jpg")?;

// build annotator
let annotator = Annotator::default()
.without_bboxes_name(true)
.without_bboxes_conf(true)
.without_mbrs_name(true)
.without_mbrs_conf(true)
.with_saveout("SAM");

// run & annotate
for (xs, _paths) in dl {
// prompt
let prompts = vec![
SamPrompt::default()
// .with_postive_point(774., 366.), // postive point
// .with_negative_point(774., 366.), // negative point
.with_bbox(215., 297., 643., 459.), // bbox
];
let ys = model.run(&xs, &prompts)?;
annotator.annotate(&xs, &ys);
}

Ok(())
}
3 changes: 2 additions & 1 deletion examples/yolo/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ fn main() -> Result<()> {
.with_confs(&[0.2, 0.15]) // class_0: 0.4, others: 0.15
// .with_names(&coco::NAMES_80)
.with_names2(&coco::KEYPOINTS_NAMES_17)
.with_find_contours(false) // find contours or not
.with_profile(args.profile);
let mut model = YOLO::new(options)?;

Expand All @@ -164,7 +165,7 @@ fn main() -> Result<()> {
let annotator = Annotator::default()
.with_skeletons(&coco::SKELETONS_16)
.with_bboxes_thickness(4)
.without_masks(true) // No masks plotting.
.without_masks(false) // No masks plotting when doing segment task.
.with_saveout("YOLO-Series");

// run & annotate
Expand Down
14 changes: 7 additions & 7 deletions src/core/annotator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,13 +340,6 @@ impl Annotator {
}
}

// masks
if !self.without_masks {
if let Some(xs) = &y.masks() {
self.plot_masks(&mut img_rgba, xs);
}
}

// bboxes
if !self.without_bboxes {
if let Some(xs) = &y.bboxes() {
Expand All @@ -368,6 +361,13 @@ impl Annotator {
}
}

// masks
if !self.without_masks {
if let Some(xs) = &y.masks() {
self.plot_masks(&mut img_rgba, xs);
}
}

// probs
if let Some(xs) = &y.probs() {
self.plot_probs(&mut img_rgba, xs);
Expand Down
30 changes: 30 additions & 0 deletions src/core/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,30 @@ impl OrtEngine {
(3, 3) => Self::_set_ixx(x, &config.i33, i, ii).unwrap_or(x_default),
(3, 4) => Self::_set_ixx(x, &config.i34, i, ii).unwrap_or(x_default),
(3, 5) => Self::_set_ixx(x, &config.i35, i, ii).unwrap_or(x_default),
(4, 0) => Self::_set_ixx(x, &config.i40, i, ii).unwrap_or(x_default),
(4, 1) => Self::_set_ixx(x, &config.i41, i, ii).unwrap_or(x_default),
(4, 2) => Self::_set_ixx(x, &config.i42, i, ii).unwrap_or(x_default),
(4, 3) => Self::_set_ixx(x, &config.i43, i, ii).unwrap_or(x_default),
(4, 4) => Self::_set_ixx(x, &config.i44, i, ii).unwrap_or(x_default),
(4, 5) => Self::_set_ixx(x, &config.i45, i, ii).unwrap_or(x_default),
(5, 0) => Self::_set_ixx(x, &config.i50, i, ii).unwrap_or(x_default),
(5, 1) => Self::_set_ixx(x, &config.i51, i, ii).unwrap_or(x_default),
(5, 2) => Self::_set_ixx(x, &config.i52, i, ii).unwrap_or(x_default),
(5, 3) => Self::_set_ixx(x, &config.i53, i, ii).unwrap_or(x_default),
(5, 4) => Self::_set_ixx(x, &config.i54, i, ii).unwrap_or(x_default),
(5, 5) => Self::_set_ixx(x, &config.i55, i, ii).unwrap_or(x_default),
(6, 0) => Self::_set_ixx(x, &config.i60, i, ii).unwrap_or(x_default),
(6, 1) => Self::_set_ixx(x, &config.i61, i, ii).unwrap_or(x_default),
(6, 2) => Self::_set_ixx(x, &config.i62, i, ii).unwrap_or(x_default),
(6, 3) => Self::_set_ixx(x, &config.i63, i, ii).unwrap_or(x_default),
(6, 4) => Self::_set_ixx(x, &config.i64_, i, ii).unwrap_or(x_default),
(6, 5) => Self::_set_ixx(x, &config.i65, i, ii).unwrap_or(x_default),
(7, 0) => Self::_set_ixx(x, &config.i70, i, ii).unwrap_or(x_default),
(7, 1) => Self::_set_ixx(x, &config.i71, i, ii).unwrap_or(x_default),
(7, 2) => Self::_set_ixx(x, &config.i72, i, ii).unwrap_or(x_default),
(7, 3) => Self::_set_ixx(x, &config.i73, i, ii).unwrap_or(x_default),
(7, 4) => Self::_set_ixx(x, &config.i74, i, ii).unwrap_or(x_default),
(7, 5) => Self::_set_ixx(x, &config.i75, i, ii).unwrap_or(x_default),
_ => todo!(),
};
v_.push(x);
Expand Down Expand Up @@ -290,6 +314,12 @@ impl OrtEngine {
TensorElementType::Int64 => {
ort::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
}
TensorElementType::Uint8 => {
ort::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn()
}
TensorElementType::Int8 => {
ort::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn()
}
_ => todo!(),
};
xs_.push(Into::<ort::SessionInputValue<'_>>::into(x_));
Expand Down
176 changes: 175 additions & 1 deletion src/core/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,30 @@ pub struct Options {
pub i33: Option<MinOptMax>,
pub i34: Option<MinOptMax>,
pub i35: Option<MinOptMax>,

pub i40: Option<MinOptMax>,
pub i41: Option<MinOptMax>,
pub i42: Option<MinOptMax>,
pub i43: Option<MinOptMax>,
pub i44: Option<MinOptMax>,
pub i45: Option<MinOptMax>,
pub i50: Option<MinOptMax>,
pub i51: Option<MinOptMax>,
pub i52: Option<MinOptMax>,
pub i53: Option<MinOptMax>,
pub i54: Option<MinOptMax>,
pub i55: Option<MinOptMax>,
pub i60: Option<MinOptMax>,
pub i61: Option<MinOptMax>,
pub i62: Option<MinOptMax>,
pub i63: Option<MinOptMax>,
pub i64_: Option<MinOptMax>,
pub i65: Option<MinOptMax>,
pub i70: Option<MinOptMax>,
pub i71: Option<MinOptMax>,
pub i72: Option<MinOptMax>,
pub i73: Option<MinOptMax>,
pub i74: Option<MinOptMax>,
pub i75: Option<MinOptMax>,
// trt related
pub trt_engine_cache_enable: bool,
pub trt_int8_enable: bool,
Expand All @@ -63,6 +86,7 @@ pub struct Options {
pub yolo_task: Option<YOLOTask>,
pub yolo_version: Option<YOLOVersion>,
pub yolo_preds: Option<YOLOPreds>,
pub find_contours: bool,
}

impl Default for Options {
Expand Down Expand Up @@ -96,6 +120,30 @@ impl Default for Options {
i33: None,
i34: None,
i35: None,
i40: None,
i41: None,
i42: None,
i43: None,
i44: None,
i45: None,
i50: None,
i51: None,
i52: None,
i53: None,
i54: None,
i55: None,
i60: None,
i61: None,
i62: None,
i63: None,
i64_: None,
i65: None,
i70: None,
i71: None,
i72: None,
i73: None,
i74: None,
i75: None,
trt_engine_cache_enable: true,
trt_int8_enable: false,
trt_fp16_enable: false,
Expand All @@ -116,6 +164,7 @@ impl Default for Options {
yolo_task: None,
yolo_version: None,
yolo_preds: None,
find_contours: false,
}
}
}
Expand Down Expand Up @@ -171,6 +220,11 @@ impl Options {
self
}

pub fn with_find_contours(mut self, x: bool) -> Self {
self.find_contours = x;
self
}

pub fn with_names(mut self, names: &[&str]) -> Self {
self.names = Some(names.iter().map(|x| x.to_string()).collect::<Vec<String>>());
self
Expand Down Expand Up @@ -360,4 +414,124 @@ impl Options {
self.i35 = Some(x);
self
}

pub fn with_i40(mut self, x: MinOptMax) -> Self {
self.i40 = Some(x);
self
}

pub fn with_i41(mut self, x: MinOptMax) -> Self {
self.i41 = Some(x);
self
}

pub fn with_i42(mut self, x: MinOptMax) -> Self {
self.i42 = Some(x);
self
}

pub fn with_i43(mut self, x: MinOptMax) -> Self {
self.i43 = Some(x);
self
}

pub fn with_i44(mut self, x: MinOptMax) -> Self {
self.i44 = Some(x);
self
}

pub fn with_i45(mut self, x: MinOptMax) -> Self {
self.i45 = Some(x);
self
}

pub fn with_i50(mut self, x: MinOptMax) -> Self {
self.i50 = Some(x);
self
}

pub fn with_i51(mut self, x: MinOptMax) -> Self {
self.i51 = Some(x);
self
}

pub fn with_i52(mut self, x: MinOptMax) -> Self {
self.i52 = Some(x);
self
}

pub fn with_i53(mut self, x: MinOptMax) -> Self {
self.i53 = Some(x);
self
}

pub fn with_i54(mut self, x: MinOptMax) -> Self {
self.i54 = Some(x);
self
}

pub fn with_i55(mut self, x: MinOptMax) -> Self {
self.i55 = Some(x);
self
}

pub fn with_i60(mut self, x: MinOptMax) -> Self {
self.i60 = Some(x);
self
}

pub fn with_i61(mut self, x: MinOptMax) -> Self {
self.i61 = Some(x);
self
}

pub fn with_i62(mut self, x: MinOptMax) -> Self {
self.i62 = Some(x);
self
}

pub fn with_i63(mut self, x: MinOptMax) -> Self {
self.i63 = Some(x);
self
}

pub fn with_i64(mut self, x: MinOptMax) -> Self {
self.i64_ = Some(x);
self
}

pub fn with_i65(mut self, x: MinOptMax) -> Self {
self.i65 = Some(x);
self
}

pub fn with_i70(mut self, x: MinOptMax) -> Self {
self.i70 = Some(x);
self
}

pub fn with_i71(mut self, x: MinOptMax) -> Self {
self.i71 = Some(x);
self
}

pub fn with_i72(mut self, x: MinOptMax) -> Self {
self.i72 = Some(x);
self
}

pub fn with_i73(mut self, x: MinOptMax) -> Self {
self.i73 = Some(x);
self
}

pub fn with_i74(mut self, x: MinOptMax) -> Self {
self.i74 = Some(x);
self
}

pub fn with_i75(mut self, x: MinOptMax) -> Self {
self.i75 = Some(x);
self
}
}
6 changes: 6 additions & 0 deletions src/core/x.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ impl From<Array<f32, IxDyn>> for X {
}
}

impl From<Vec<f32>> for X {
fn from(x: Vec<f32>) -> Self {
Self(Array::from_vec(x).into_dyn().into_owned())
}
}

impl std::ops::Deref for X {
type Target = Array<f32, IxDyn>;

Expand Down
Loading

0 comments on commit b7e2018

Please sign in to comment.