[iree-import-onnx] improve handling of large models #19217

Merged: 19 commits merged into iree-org:main on Dec 9, 2024

Conversation

zjgarvey (Contributor) commented on Nov 19, 2024

This PR adds a few options:

  1. --large-model allows disabling the onnx model checker if a user knows ahead of time that the model is too large. It also avoids loading the external weights into memory unless the parameters are being saved.
  2. --num-initializers-threshold allows storing initializers to the irpa file in batches with a specified number of entries. This can reduce the memory overhead of first gathering all of the initializers, then saving them to the irpa at once.
  3. --externalize-inputs-threshold allows converting inputs to externalized weights. This is useful for the following workflow: export a HF pytorch model that ships safetensors weights, save a .irpa directly from the safetensor weights, and export to onnx with export_params=False and do_constant_folding=False (which converts weights to inputs and avoids folding weights with ops like transposes). When importing to mlir, set externalize-inputs-threshold=<num_original_inputs> and the inputs at and beyond that threshold are converted to util.global ops. A rough sketch of this workflow follows the list.
  4. --save-params/--no-save-params factors parameter saving out of import_initializer, so one can skip saving parameters with --no-save-params. This is useful for debugging compilation failures.
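
For illustration, here is a minimal sketch of the item-3 workflow. The model names, the use of transformers.AutoModel, and the iree-convert-parameters / iree-import-onnx invocations in the comments are assumptions for the sketch, not taken from this PR.

import torch
from transformers import AutoModel  # assumed HF model; any torch.nn.Module works

model = AutoModel.from_pretrained("some/hf-model")  # weights shipped as safetensors
example_input = torch.randint(0, 1000, (1, 16))

# Export with weights left as extra graph inputs instead of embedded initializers.
torch.onnx.export(
    model,
    (example_input,),
    "model.onnx",
    export_params=False,        # weights become graph inputs
    do_constant_folding=False,  # avoid folding weights with ops like transposes
)

# Then, roughly (shell; flag spellings are assumptions):
#   iree-convert-parameters --parameters=model.safetensors --output=model.irpa
#   iree-import-onnx model.onnx \
#       --externalize-inputs-threshold=<num_original_inputs> --no-save-params \
#       -o model.mlir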

TODO:

Figure out what to do about loading the onnx model and updating the opset version. Models larger than 2GB fail on opset version updating, but it's possible to update the opset version without weights in a somewhat hacky way.

Add documentation

…odel at all

Also imports iree-runtime only when initializing the iree node importer, and fails with better error messaging.

Signed-off-by: zjgarvey <[email protected]>
@zjgarvey requested a review from vinayakdsci on November 25, 2024 21:49
@zjgarvey marked this pull request as ready for review on December 3, 2024 16:33
@ScottTodd added the integrations/onnx (ONNX integration work) label on Dec 3, 2024
Comment on lines 183 to 189
parser.add_argument(
    "--large-model",
    help="Setting this to true is recommended for large models."
    " It will bypass loading external weights and running the onnx checker to determine the model size.",
    action=argparse.BooleanOptionalAction,
    default=False,
)

ScottTodd (Member):

Thinking about the ergonomics here... I'd prefer if users don't need to specify a long list of arguments for common cases. We may be able to infer settings like this from the other parameters and input files.

Could do some heuristics-based check first, log when the model is too big, then add --force-enable-checks or --force-disable-checks arguments...


zjgarvey (Contributor, author):


The biggest issue I have is that inferring the model size (with weights) is not simple to do with onnx.

I wanted to add this flag specifically so we could avoid running the checker, which is used simply to see if the model is >2GB in size. My issues with this checker are that it fails due to unrelated issues (e.g., duplicated metadata_props) and that it actually hard-crashes for some particularly large models.

The other reason for adding this flag is to allow running the importer without loading the external onnx weights until the moment we want to save a param file.


ScottTodd (Member):


The biggest issue I have is that inferring the model size (with weights) is not simple to do with onnx.

Can you check the file size? Not sure what I'm missing :P


zjgarvey (Contributor, author):


If the weights are external, you need to check the size of those files too. However, there isn't a single metadata field saying where all of the external weights are stored, so you'd need to iterate through the initializers and op attributes of the model to find all external data references, add them to a set, then check the disk usage of each of those files.
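
A rough sketch of that scan with standard onnx APIs (it only walks top-level graph initializers; Constant node attributes and subgraphs would need the same treatment to be exhaustive):

import os
import onnx

def total_model_size_bytes(model_path: str) -> int:
    # Load the model structure only; leave external tensor data on disk.
    model = onnx.load(model_path, load_external_data=False)
    base_dir = os.path.dirname(model_path)
    external_files = set()
    for tensor in model.graph.initializer:
        if tensor.data_location == onnx.TensorProto.EXTERNAL:
            # external_data is a list of key/value entries; "location" holds the
            # path of the external weight file relative to the model.
            for entry in tensor.external_data:
                if entry.key == "location":
                    external_files.add(os.path.join(base_dir, entry.value))
    return os.path.getsize(model_path) + sum(os.path.getsize(f) for f in external_files)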


ScottTodd (Member):


Oof. Aren't the files at least downloaded in a group?

Found some docs:

I'm missing a mental model for how users typically work with these "large" models 🤔 (2GB is not "large" in 2024 :P)


zjgarvey (Contributor, author):


2GB is a frustrating limit imposed by protobuf.

It's only really problematic when trying to save a model without externalizing weights. Somewhat unrelated to the importer, but applying onnxruntime basic optimizations, for example, will attempt to save an optimized model, and sometimes the optimized model is larger than 2GB even when the original isn't. You can apparently get the inference session to externalize the optimized model's weights, but how is one supposed to know when this is necessary?

In general, the checker takes a lot of time when the model is large, and there isn't much point in running it if we already know the model is large. Also, if a model is very large, we should avoid loading model weights unless absolutely necessary.

In any case, if you find out a good way to say "hey onnx, how big is this model", please let me know and we can just replace the checker try/except block with an actual disk usage check.


ScottTodd (Member):


I think I'd need to go through some of the workflows myself with and without large external weights to form a good opinion. Can you point me to an example set of commands to run / models to try?

Here's another angle to think of this from: what happens if a user always sets --large-model? Are they missing out on some error checking? Will some optimizations not run? Are some features not supported? Based on the help text

        help="Setting this to true is recommended for large models."
        " It will bypass loading external weights and running the onnx checker to determine the model size.",

as a user, I might just set the flag no matter what... in which case how much are we gaining by making it default to off and including it at all?

Comment on lines 306 to 310
param_count += 1
if (
    self.param_data.initializer_threshold
    and param_count % self.param_data.initializer_threshold == 0
):

ScottTodd (Member):


Is the logic with % correct here?

    parser.add_argument(
        "--num-initializers-threshold",
        help="The maximum number of initializer tensors to be stored in-memory when creating a param archive.",
        type=int,
    )

2. --num-initializers-threshold allows storing initializers to the irpa file in batches with a specified number of entries. This can reduce the memory overhead of first gathering all of the initializers, then saving them to the irpa at once.

The code doesn't seem to line up with the docs or PR description, unless I'm misreading it.


zjgarvey (Contributor, author):


Yeah, if param_count % initializer_threshold == 0, then param_count is a multiple of the threshold, and we trigger saving the current param archive to the .irpa file. This moves the in-memory tensors held in the param archive to the .irpa file.


ScottTodd (Member):


Oooooh I see. So like this?

param_archive = rt.ParameterIndex()
for i in range(len(self.globals)):
  param_archive.add_buffer(...)

  # Observed 'threshold' parameters already, flush from memory to disk.
  if i % initializer_threshold == 0:
    param_archive = param_archive.create_archive_file(...)

# Flush any remaining parameters to disk.
param_archive = param_archive.create_archive_file(...)

If so, a few observations:

  • Not all parameters are the same size, so a threshold based on the number of parameters could quite easily hit a degenerate case. Consider having 1000 parameters, with the first 10 being 1GB and the last 990 being 1KB. Flushing after even 2-3 would use substantial RAM at the start, then the flushing overhead would be slow for the long tail. Maybe threshold by memory used, not by parameter count?
  • What does the data flow actually look like under these python API calls? The function names don't make it obvious to me what the costs are.
  • This variable overwriting via param_archive = param_archive.create_archive_file() is sketchy on the surface... is the intent to append to a single file on disk? Or write to multiple files?


zjgarvey (Contributor, author):


Yeah, this is the result of some trial and error. I've basically run this on my laptop with another terminal checking memory usage, and noticed that flushing the parameters early takes more time but reduces the memory overhead. I've also definitely noticed that some batches are much larger than others, but I'm not sure how to set it up to flush at memory thresholds. A good approach might be to count the number of elements in each param and multiply it by the bit width of the data type; then, if adding a new param would exceed a size threshold, flush the archive before adding it (roughly sketched below). I'll try to mock this up in the next patch.

Calling create_archive_file seems to do the following (again, learned by trial and error):

  1. params that already exist in the irpa file, and which correspond to references in the param_archive, do not get modified.
  2. new params that exist in-memory get written to the irpa file.
  3. setting param_archive equal to the result seems to convert the param_archive to only store references to the irpa file, so when we add more params to this, it will then contain both the existing irpa file references in addition to the newly added in-memory params.
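
A rough sketch of that element-size accounting (the helper, its arguments, and the flush pattern are hypothetical, not the PR's code; params are assumed to be numpy arrays):

def maybe_flush(param_archive, pending_bytes, next_param, max_bytes, irpa_path):
    # Estimate the new param's footprint: element count times bytes per element.
    param_bytes = next_param.size * next_param.dtype.itemsize
    if pending_bytes and pending_bytes + param_bytes > max_bytes:
        # Flush the in-memory entries to the .irpa before adding the new param,
        # relying on the create_archive_file behavior described above.
        param_archive = param_archive.create_archive_file(irpa_path)
        pending_bytes = 0
    return param_archive, pending_bytes + param_bytes

The caller would thread param_archive and the running byte count through the import loop in place of the current param_count check.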


zjgarvey (Contributor, author):


Let me know what you think of the changes!


ScottTodd (Member):


LGTM, except for

Calling create_archive_file seems to do the following (again, learned by trial and error)

I'd rather make the APIs explicit. It seems like you're relying on undocumented side effects here, so I'm worried that we're in https://www.hyrumslaw.com/ territory :P


zjgarvey (Contributor, author):


Yeah, that's a good point. I'm doing a bit more digging into the API (it's calling this:

.def(
    "create_archive_file",
    [](ParameterIndex &self, std::string file_path,
       iree_io_physical_offset_t file_offset,
       ParameterIndex *explicit_target_index) {
).

It might be better to hold two parameter indices: a target_index for the irpa file contents, and a running_index for in-memory params. I'll update when I know more.

@ScottTodd self-requested a review on December 6, 2024 17:55
@zjgarvey force-pushed the importer_large_models branch from 5333f05 to 152c04c on December 6, 2024 17:58
@zjgarvey changed the title from "[WIP] [iree-import-onnx] improve handling of large models" to "[iree-import-onnx] improve handling of large models" on Dec 6, 2024

zjgarvey (Contributor, author) commented on Dec 7, 2024

@ScottTodd I'm thinking about the --large-model flag comment #19217 (comment), and the only downsides with using this flag:

  1. Creates temp files during file-based shape inference
  2. Doesn't allow updating opset version, since weights might not be loaded (external data references get wiped when updating opset version).

ScottTodd (Member) replied, quoting the above:

@ScottTodd I'm thinking about the --large-model flag comment #19217 (comment), and the only downsides with using this flag:

  1. Creates temp files during file-based shape inference
  2. Doesn't allow updating opset version, since weights might not be loaded (external data references get wiped when updating opset version).

Temp files are fine (within reason).

Opset version updating is pretty important for older models and unit tests, but I suspect it will be less important for large models.

zjgarvey (Contributor, author) commented on Dec 9, 2024

@ScottTodd

After looking at the API more carefully, it looks like param_archive = param_archive.create_archive_file(path) is rather inefficient, since it needs to resize the file and write both the old and new contents all over again. Doing this repeatedly is quite expensive.

Instead, I decided to save params to smaller temp files, recording the saved params into an accumulating target_index each time. Once all the temp files are generated, the target_index is saved to the specified param_path, which transfers the data from each of those temp files into the irpa at param_path. This only takes about twice as long as not specifying a param-gb-threshold, since the data needs to be written twice. It would be faster to just save multiple .irpa files, but that would make it really annoying to specify params in iree-run-module. I've updated the description of that arg to mention that it is only recommended for machines with low RAM.
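
Roughly, the new flow looks like the sketch below. The add_buffer signature, the target_index keyword, and the helper itself are assumptions based on the binding discussed earlier, not the PR's actual code:

import tempfile
import iree.runtime as rt

def save_params_in_batches(named_arrays, param_path, gb_threshold):
    # named_arrays: iterable of (name, numpy array).
    target_index = rt.ParameterIndex()  # accumulates a reference to every saved entry
    batch = rt.ParameterIndex()         # in-memory entries for the current batch
    pending_bytes = 0

    def flush(current_batch):
        # Write the current batch to a small temp .irpa, recording its entries
        # into target_index (keyword name assumed from the binding above).
        tmp = tempfile.NamedTemporaryFile(suffix=".irpa", delete=False)
        tmp.close()
        current_batch.create_archive_file(tmp.name, target_index=target_index)
        return rt.ParameterIndex()

    for name, array in named_arrays:
        batch.add_buffer(name, array)  # assumed signature: (key, buffer)
        pending_bytes += array.nbytes
        if pending_bytes > gb_threshold * 2**30:
            batch = flush(batch)
            pending_bytes = 0
    batch = flush(batch)

    # Transfer everything referenced by target_index (i.e. the temp files'
    # contents) into the final archive, so iree-run-module only needs one file.
    target_index.create_archive_file(param_path)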

I've also updated a few other descriptions based on the conversations we've had.

Would you mind taking another look at the recent changes?

@zjgarvey requested a review from ScottTodd on December 9, 2024 21:30
ScottTodd (Member) left a review:

Okay, this LGTM now. Thanks for your patience; user-facing API changes deserve extra review focus, since changing or removing those APIs later comes at a high cost compared to changing internal implementation details.

@zjgarvey merged commit ab3c9bb into iree-org:main on Dec 9, 2024 (38 checks passed)
zjgarvey (Contributor, author) commented on Dec 9, 2024

@ScottTodd I really appreciate the time and attention that you put into reviewing this (and other PRs). Thanks for all the helpful feedback.

ScottTodd (Member) replied, quoting the PR description's TODO:

TODO:

Add documentation

BTW, this will be worth calling out in release notes at #19192. Bonus points if the docs at https://iree.dev/guides/ml-frameworks/onnx/ are updated
