diff --git a/examples/sam/main.cpp b/examples/sam/main.cpp index 7c4130fdd..38d5e2734 100644 --- a/examples/sam/main.cpp +++ b/examples/sam/main.cpp @@ -296,6 +296,22 @@ struct sam_image_f32 { std::vector data; }; +struct sam_params { + int32_t seed = -1; // RNG seed + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + + std::string model = "models/sam-vit-b/ggml-model-f16.bin"; // model path + std::string fname_inp = "img.jpg"; + std::string fname_out = "img.out"; + float mask_threshold = 0.f; + float iou_threshold = 0.88f; + float stability_score_threshold = 0.95f; + float stability_score_offset = 1.0f; + float eps = 1e-6f; + float eps_decoder_transformer = 1e-5f; + sam_point pt = { 414.375f, 162.796875f, }; +}; + void print_t_f32(const char* title, struct ggml_tensor * t, int n = 10) { printf("%s\n", title); float * data = (float *)t->data; @@ -469,12 +485,12 @@ bool sam_image_preprocess(const sam_image_u8 & img, sam_image_f32 & res) { } // load the model's weights from a file -bool sam_model_load(const std::string & fname, sam_model & model) { - fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); +bool sam_model_load(const sam_params & params, sam_model & model) { + fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, params.model.c_str()); - auto fin = std::ifstream(fname, std::ios::binary); + auto fin = std::ifstream(params.model, std::ios::binary); if (!fin) { - fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str()); + fprintf(stderr, "%s: failed to open '%s'\n", __func__, params.model.c_str()); return false; } @@ -483,13 +499,21 @@ bool sam_model_load(const std::string & fname, sam_model & model) { uint32_t magic; fin.read((char *) &magic, sizeof(magic)); if (magic != 0x67676d6c) { - fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); + fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, params.model.c_str()); return false; } } // load hparams { + // Override defaults with user choices + model.hparams.mask_threshold = params.mask_threshold; + model.hparams.iou_threshold = params.iou_threshold; + model.hparams.stability_score_threshold = params.stability_score_threshold; + model.hparams.stability_score_offset = params.stability_score_offset; + model.hparams.eps = params.eps; + model.hparams.eps_decoder_transformer = params.eps_decoder_transformer; + auto & hparams = model.hparams; fin.read((char *) &hparams.n_enc_state, sizeof(hparams.n_enc_state)); @@ -510,6 +534,7 @@ bool sam_model_load(const std::string & fname, sam_model & model) { printf("%s: qntvr = %d\n", __func__, qntvr); hparams.ftype %= GGML_QNT_VERSION_FACTOR; + } // for the big tensors, we have the option to store the data in 16-bit floats or quantized @@ -517,7 +542,7 @@ bool sam_model_load(const std::string & fname, sam_model & model) { ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype)); if (wtype == GGML_TYPE_COUNT) { fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n", - __func__, fname.c_str(), model.hparams.ftype); + __func__, params.model.c_str(), model.hparams.ftype); return false; } @@ -1791,7 +1816,7 @@ bool sam_decode_mask( return true; } -bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state & state) { +bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state & state, const std::string & fname) { if (state.low_res_masks->ne[2] == 0) return true; if (state.low_res_masks->ne[2] != state.iou_predictions->ne[0]) { printf("Error: number of masks (%d) does not match number of iou predictions (%d)\n", (int)state.low_res_masks->ne[2], (int)state.iou_predictions->ne[0]); @@ -1938,7 +1963,7 @@ bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state printf("Mask %d: iou = %f, stability_score = %f, bbox (%d, %d), (%d, %d)\n", i, iou_data[i], stability_score, min_ix, max_ix, min_iy, max_iy); - std::string filename = "mask_out_" + std::to_string(i) + ".png"; + std::string filename = fname + std::to_string(i) + ".png"; if (!stbi_write_png(filename.c_str(), res.nx, res.ny, 1, res.data.data(), res.nx)) { printf("%s: failed to write mask %s\n", __func__, filename.c_str()); return false; @@ -1967,7 +1992,7 @@ struct ggml_cgraph * sam_build_fast_graph( prompt_encoder_result enc_res = sam_encode_prompt(model, ctx0, gf, state, nx, ny, point); if (!enc_res.embd_prompt_sparse || !enc_res.embd_prompt_dense) { - fprintf(stderr, "%s: failed to encode prompt\n", __func__); + fprintf(stderr, "%s: failed to encode prompt (%f, %f)\n", __func__, point.x, point.y); return {}; } @@ -1986,14 +2011,6 @@ struct ggml_cgraph * sam_build_fast_graph( return gf; } -struct sam_params { - int32_t seed = -1; // RNG seed - int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); - - std::string model = "models/sam-vit-b/ggml-model-f16.bin"; // model path - std::string fname_inp = "img.jpg"; - std::string fname_out = "img.out"; -}; void sam_print_usage(int argc, char ** argv, const sam_params & params) { fprintf(stderr, "usage: %s [options]\n", argv[0]); @@ -2007,7 +2024,23 @@ void sam_print_usage(int argc, char ** argv, const sam_params & params) { fprintf(stderr, " -i FNAME, --inp FNAME\n"); fprintf(stderr, " input file (default: %s)\n", params.fname_inp.c_str()); fprintf(stderr, " -o FNAME, --out FNAME\n"); - fprintf(stderr, " output file (default: %s)\n", params.fname_out.c_str()); + fprintf(stderr, " mask file name prefix (default: %s)\n", params.fname_out.c_str()); + fprintf(stderr, "SAM hyperparameters:\n"); + fprintf(stderr, " -mt FLOAT, --mask-threshold\n"); + fprintf(stderr, " mask threshold (default: %f)\n", params.mask_threshold); + fprintf(stderr, " -it FLOAT, --iou-threshold\n"); + fprintf(stderr, " iou threshold (default: %f)\n", params.iou_threshold); + fprintf(stderr, " -st FLOAT, --score-threshold\n"); + fprintf(stderr, " score threshold (default: %f)\n", params.stability_score_threshold); + fprintf(stderr, " -so FLOAT, --score-offset\n"); + fprintf(stderr, " score offset (default: %f)\n", params.stability_score_offset); + fprintf(stderr, " -e FLOAT, --epsilon\n"); + fprintf(stderr, " epsilon (default: %f)\n", params.eps); + fprintf(stderr, " -ed FLOAT, --epsilon-decoder-transformer\n"); + fprintf(stderr, " epsilon decoder transformer (default: %f)\n", params.eps_decoder_transformer); + fprintf(stderr, "SAM prompt:\n"); + fprintf(stderr, " -p TUPLE, --point-prompt\n"); + fprintf(stderr, " point to be used as prompt for SAM (default: %f,%f). Must be in a format FLOAT,FLOAT \n", params.pt.x, params.pt.y); fprintf(stderr, "\n"); } @@ -2025,6 +2058,34 @@ bool sam_params_parse(int argc, char ** argv, sam_params & params) { params.fname_inp = argv[++i]; } else if (arg == "-o" || arg == "--out") { params.fname_out = argv[++i]; + } else if (arg == "-mt" || arg == "--mask-threshold") { + params.mask_threshold = std::stof(argv[++i]); + } else if (arg == "-it" || arg == "--iou-threshold") { + params.iou_threshold = std::stof(argv[++i]); + } else if (arg == "-st" || arg == "--score-threshold") { + params.stability_score_threshold = std::stof(argv[++i]); + } else if (arg == "-so" || arg == "--score-offset") { + params.stability_score_offset = std::stof(argv[++i]); + } else if (arg == "-e" || arg == "--epsilon") { + params.eps = std::stof(argv[++i]); + } else if (arg == "-ed" || arg == "--epsilon-decoder-transformer") { + params.eps_decoder_transformer = std::stof(argv[++i]); + } else if (arg == "-p" || arg == "--point-prompt") { + // TODO multiple points per model invocation + char* point = argv[++i]; + + char* coord = strtok(point, ","); + if (!coord){ + fprintf(stderr, "Error while parsing prompt!\n"); + exit(1); + } + params.pt.x = std::stof(coord); + coord = strtok(NULL, ","); + if (!coord){ + fprintf(stderr, "Error while parsing prompt!\n"); + exit(1); + } + params.pt.y = std::stof(coord); } else if (arg == "-h" || arg == "--help") { sam_print_usage(argc, argv, params); exit(0); @@ -2078,7 +2139,7 @@ int main(int argc, char ** argv) { { const int64_t t_start_us = ggml_time_us(); - if (!sam_model_load(params.model, model)) { + if (!sam_model_load(params, model)) { fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); return 1; } @@ -2147,10 +2208,11 @@ int main(int argc, char ** argv) { state.buf_compute_fast.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead()); state.allocr = ggml_allocr_new_measure(tensor_alignment); - // TODO: user input - const sam_point pt = { 414.375f, 162.796875f, }; + // TODO: more varied prompts + fprintf(stderr, "prompt: (%f, %f)\n", params.pt.x, params.pt.y); + // measure memory requirements for the graph - struct ggml_cgraph * gf_measure = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt); + struct ggml_cgraph * gf_measure = sam_build_fast_graph(model, state, img0.nx, img0.ny, params.pt); if (!gf_measure) { fprintf(stderr, "%s: failed to build fast graph to measure\n", __func__); return 1; @@ -2166,7 +2228,7 @@ int main(int argc, char ** argv) { // compute the graph with the measured exact memory requirements from above ggml_allocr_reset(state.allocr); - struct ggml_cgraph * gf = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt); + struct ggml_cgraph * gf = sam_build_fast_graph(model, state, img0.nx, img0.ny, params.pt); if (!gf) { fprintf(stderr, "%s: failed to build fast graph\n", __func__); return 1; @@ -2182,7 +2244,7 @@ int main(int argc, char ** argv) { state.allocr = NULL; } - if (!sam_write_masks(model.hparams, img0.nx, img0.ny, state)) { + if (!sam_write_masks(model.hparams, img0.nx, img0.ny, state, params.fname_out)) { fprintf(stderr, "%s: failed to write masks\n", __func__); return 1; }