Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Passing parameters and simple prompt on SAM CLI #598

Merged
merged 1 commit into from
Nov 2, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 86 additions & 24 deletions examples/sam/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,22 @@ struct sam_image_f32 {
std::vector<float> data;
};

struct sam_params {
int32_t seed = -1; // RNG seed
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());

std::string model = "models/sam-vit-b/ggml-model-f16.bin"; // model path
std::string fname_inp = "img.jpg";
std::string fname_out = "img.out";
float mask_threshold = 0.f;
float iou_threshold = 0.88f;
float stability_score_threshold = 0.95f;
float stability_score_offset = 1.0f;
float eps = 1e-6f;
float eps_decoder_transformer = 1e-5f;
sam_point pt = { 414.375f, 162.796875f, };
};

void print_t_f32(const char* title, struct ggml_tensor * t, int n = 10) {
printf("%s\n", title);
float * data = (float *)t->data;
Expand Down Expand Up @@ -469,12 +485,12 @@ bool sam_image_preprocess(const sam_image_u8 & img, sam_image_f32 & res) {
}

// load the model's weights from a file
bool sam_model_load(const std::string & fname, sam_model & model) {
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
bool sam_model_load(const sam_params & params, sam_model & model) {
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, params.model.c_str());

auto fin = std::ifstream(fname, std::ios::binary);
auto fin = std::ifstream(params.model, std::ios::binary);
if (!fin) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
fprintf(stderr, "%s: failed to open '%s'\n", __func__, params.model.c_str());
return false;
}

Expand All @@ -483,13 +499,21 @@ bool sam_model_load(const std::string & fname, sam_model & model) {
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
if (magic != 0x67676d6c) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, params.model.c_str());
return false;
}
}

// load hparams
{
// Override defaults with user choices
model.hparams.mask_threshold = params.mask_threshold;
model.hparams.iou_threshold = params.iou_threshold;
model.hparams.stability_score_threshold = params.stability_score_threshold;
model.hparams.stability_score_offset = params.stability_score_offset;
model.hparams.eps = params.eps;
model.hparams.eps_decoder_transformer = params.eps_decoder_transformer;

auto & hparams = model.hparams;

fin.read((char *) &hparams.n_enc_state, sizeof(hparams.n_enc_state));
Expand All @@ -510,14 +534,15 @@ bool sam_model_load(const std::string & fname, sam_model & model) {
printf("%s: qntvr = %d\n", __func__, qntvr);

hparams.ftype %= GGML_QNT_VERSION_FACTOR;

}

// for the big tensors, we have the option to store the data in 16-bit floats or quantized
// in order to save memory and also to speed up the computation
ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
if (wtype == GGML_TYPE_COUNT) {
fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
__func__, fname.c_str(), model.hparams.ftype);
__func__, params.model.c_str(), model.hparams.ftype);
return false;
}

Expand Down Expand Up @@ -1791,7 +1816,7 @@ bool sam_decode_mask(
return true;
}

bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state & state) {
bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state & state, const std::string & fname) {
if (state.low_res_masks->ne[2] == 0) return true;
if (state.low_res_masks->ne[2] != state.iou_predictions->ne[0]) {
printf("Error: number of masks (%d) does not match number of iou predictions (%d)\n", (int)state.low_res_masks->ne[2], (int)state.iou_predictions->ne[0]);
Expand Down Expand Up @@ -1938,7 +1963,7 @@ bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state
printf("Mask %d: iou = %f, stability_score = %f, bbox (%d, %d), (%d, %d)\n",
i, iou_data[i], stability_score, min_ix, max_ix, min_iy, max_iy);

std::string filename = "mask_out_" + std::to_string(i) + ".png";
std::string filename = fname + std::to_string(i) + ".png";
if (!stbi_write_png(filename.c_str(), res.nx, res.ny, 1, res.data.data(), res.nx)) {
printf("%s: failed to write mask %s\n", __func__, filename.c_str());
return false;
Expand Down Expand Up @@ -1967,7 +1992,7 @@ struct ggml_cgraph * sam_build_fast_graph(

prompt_encoder_result enc_res = sam_encode_prompt(model, ctx0, gf, state, nx, ny, point);
if (!enc_res.embd_prompt_sparse || !enc_res.embd_prompt_dense) {
fprintf(stderr, "%s: failed to encode prompt\n", __func__);
fprintf(stderr, "%s: failed to encode prompt (%f, %f)\n", __func__, point.x, point.y);
return {};
}

Expand All @@ -1986,14 +2011,6 @@ struct ggml_cgraph * sam_build_fast_graph(

return gf;
}
struct sam_params {
int32_t seed = -1; // RNG seed
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());

std::string model = "models/sam-vit-b/ggml-model-f16.bin"; // model path
std::string fname_inp = "img.jpg";
std::string fname_out = "img.out";
};

void sam_print_usage(int argc, char ** argv, const sam_params & params) {
fprintf(stderr, "usage: %s [options]\n", argv[0]);
Expand All @@ -2007,7 +2024,23 @@ void sam_print_usage(int argc, char ** argv, const sam_params & params) {
fprintf(stderr, " -i FNAME, --inp FNAME\n");
fprintf(stderr, " input file (default: %s)\n", params.fname_inp.c_str());
fprintf(stderr, " -o FNAME, --out FNAME\n");
fprintf(stderr, " output file (default: %s)\n", params.fname_out.c_str());
fprintf(stderr, " mask file name prefix (default: %s)\n", params.fname_out.c_str());
fprintf(stderr, "SAM hyperparameters:\n");
fprintf(stderr, " -mt FLOAT, --mask-threshold\n");
fprintf(stderr, " mask threshold (default: %f)\n", params.mask_threshold);
fprintf(stderr, " -it FLOAT, --iou-threshold\n");
fprintf(stderr, " iou threshold (default: %f)\n", params.iou_threshold);
fprintf(stderr, " -st FLOAT, --score-threshold\n");
fprintf(stderr, " score threshold (default: %f)\n", params.stability_score_threshold);
fprintf(stderr, " -so FLOAT, --score-offset\n");
fprintf(stderr, " score offset (default: %f)\n", params.stability_score_offset);
fprintf(stderr, " -e FLOAT, --epsilon\n");
fprintf(stderr, " epsilon (default: %f)\n", params.eps);
fprintf(stderr, " -ed FLOAT, --epsilon-decoder-transformer\n");
fprintf(stderr, " epsilon decoder transformer (default: %f)\n", params.eps_decoder_transformer);
fprintf(stderr, "SAM prompt:\n");
fprintf(stderr, " -p TUPLE, --point-prompt\n");
fprintf(stderr, " point to be used as prompt for SAM (default: %f,%f). Must be in a format FLOAT,FLOAT \n", params.pt.x, params.pt.y);
fprintf(stderr, "\n");
}

Expand All @@ -2025,6 +2058,34 @@ bool sam_params_parse(int argc, char ** argv, sam_params & params) {
params.fname_inp = argv[++i];
} else if (arg == "-o" || arg == "--out") {
params.fname_out = argv[++i];
} else if (arg == "-mt" || arg == "--mask-threshold") {
params.mask_threshold = std::stof(argv[++i]);
} else if (arg == "-it" || arg == "--iou-threshold") {
params.iou_threshold = std::stof(argv[++i]);
} else if (arg == "-st" || arg == "--score-threshold") {
params.stability_score_threshold = std::stof(argv[++i]);
} else if (arg == "-so" || arg == "--score-offset") {
params.stability_score_offset = std::stof(argv[++i]);
} else if (arg == "-e" || arg == "--epsilon") {
params.eps = std::stof(argv[++i]);
} else if (arg == "-ed" || arg == "--epsilon-decoder-transformer") {
params.eps_decoder_transformer = std::stof(argv[++i]);
} else if (arg == "-p" || arg == "--point-prompt") {
// TODO multiple points per model invocation
char* point = argv[++i];

char* coord = strtok(point, ",");
if (!coord){
fprintf(stderr, "Error while parsing prompt!\n");
exit(1);
}
params.pt.x = std::stof(coord);
coord = strtok(NULL, ",");
if (!coord){
fprintf(stderr, "Error while parsing prompt!\n");
exit(1);
}
params.pt.y = std::stof(coord);
} else if (arg == "-h" || arg == "--help") {
sam_print_usage(argc, argv, params);
exit(0);
Expand Down Expand Up @@ -2078,7 +2139,7 @@ int main(int argc, char ** argv) {
{
const int64_t t_start_us = ggml_time_us();

if (!sam_model_load(params.model, model)) {
if (!sam_model_load(params, model)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
return 1;
}
Expand Down Expand Up @@ -2147,10 +2208,11 @@ int main(int argc, char ** argv) {
state.buf_compute_fast.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
state.allocr = ggml_allocr_new_measure(tensor_alignment);

// TODO: user input
const sam_point pt = { 414.375f, 162.796875f, };
// TODO: more varied prompts
fprintf(stderr, "prompt: (%f, %f)\n", params.pt.x, params.pt.y);

// measure memory requirements for the graph
struct ggml_cgraph * gf_measure = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt);
struct ggml_cgraph * gf_measure = sam_build_fast_graph(model, state, img0.nx, img0.ny, params.pt);
if (!gf_measure) {
fprintf(stderr, "%s: failed to build fast graph to measure\n", __func__);
return 1;
Expand All @@ -2166,7 +2228,7 @@ int main(int argc, char ** argv) {
// compute the graph with the measured exact memory requirements from above
ggml_allocr_reset(state.allocr);

struct ggml_cgraph * gf = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt);
struct ggml_cgraph * gf = sam_build_fast_graph(model, state, img0.nx, img0.ny, params.pt);
if (!gf) {
fprintf(stderr, "%s: failed to build fast graph\n", __func__);
return 1;
Expand All @@ -2182,7 +2244,7 @@ int main(int argc, char ** argv) {
state.allocr = NULL;
}

if (!sam_write_masks(model.hparams, img0.nx, img0.ny, state)) {
if (!sam_write_masks(model.hparams, img0.nx, img0.ny, state, params.fname_out)) {
fprintf(stderr, "%s: failed to write masks\n", __func__);
return 1;
}
Expand Down