Cache model loading in model card #299
Conversation
Thank you for taking this issue. I'm unsure about the implementation here, which I admit is not a trivial matter. If I understand correctly, you're using an ... To me, this seems to be a bit "hacky" and I would like to suggest a different approach. In my suggestion, we would leave ... I did a quick and dirty implementation of what it could look like (in ...):
from hashlib import sha256
from functools import cached_property
...

class Card:
    def __init__(...):
        ...
        self._model_hash = ""
        self._populate_template()

    def get_model(self) -> Any:
        """..."""
        if isinstance(self.model, (str, Path)) and hasattr(self, "_model"):
            hash_obj = sha256()
            buf_size = 2 ** 20  # load in chunks to save memory
            with open(self.model, "rb") as f:
                for chunk in iter(lambda: f.read(buf_size), b""):
                    hash_obj.update(chunk)
            model_hash = hash_obj.hexdigest()
            # if hash changed, invalidate cache by deleting attribute
            if model_hash != self._model_hash:
                del self._model
                self._model_hash = model_hash
        return self._model

    @cached_property
    def _model(self):
        model = _load_model(self.model, self.trusted)
        return model

What do you think about that? Of course, it would require some comments and tests, but I hope you get the general idea.
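As a self-contained illustration of the same pattern (a hypothetical class and file, not skops code), the cached_property cache can be invalidated whenever the file's hash changes; this sketch drops the cached attribute via the instance __dict__ rather than hasattr, which is just one way to do it:

from functools import cached_property
from hashlib import sha256
from pathlib import Path


class FileBackedValue:
    """Minimal sketch of hash-based invalidation of a cached_property."""

    def __init__(self, path):
        self.path = Path(path)
        self._content_hash = ""

    def get_value(self):
        new_hash = sha256(self.path.read_bytes()).hexdigest()
        if new_hash != self._content_hash:
            # first call or file changed: drop the cached value if present
            self.__dict__.pop("_value", None)
            self._content_hash = new_hash
        return self._value

    @cached_property
    def _value(self):
        # the "expensive" load runs only when the cache is empty
        return self.path.read_text()


path = Path("example.txt")
path.write_text("first")
obj = FileBackedValue(path)
assert obj.get_value() == "first"
assert obj.get_value() == "first"   # cache hit, file is not re-parsed
path.write_text("second")
assert obj.get_value() == "second"  # hash changed, cache invalidated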
It's okay, the diff is still shown correctly, right? However, please make sure to correct this for the next PR.
I understand your approach and I agree that it fits better conceptually. I took your implementation because it works out of the box, no errors. I wrote a test for it and would appreciate some feedback. Thanks a lot for the help on this PR!
Thanks for the updates. I haven't done a proper review yet, but I saw that some changes were unrelated to the additions of this PR. Could you please clean those up? Maybe those were changed by your IDE automatically? Also, it seems that there are ... Finally, the docs are not building. I think it's the same issue as in #207, so whatever fixes that should work here too.
Thank you for your comments. I'm in the process of fixing the docs error; I have asked a question about that in #207.
There's a merge conflict here, and the CI hasn't run completely somehow. Could you please merge with ...?
Thanks for updating the PR. It still shows a lot of unrelated changes in the diff, could you please remove them?
Regarding the test, I have to admit I don't quite understand it. For example, what does this test?
assert str(card._model_hash) == card.__dict__["_model_hash"]
I think what I would like to see is a more high level test, i.e. nothing that involves any hashes, since those are implementation details. One way would be to mock _load_model and assert that, when card.get_model() is called, _load_model is only called once, the first time, and after that it's not called anymore. Then, only when the underlying model is overwritten, should it be called again. WDYT?
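As a generic illustration of that call-counting pattern (plain unittest.mock, nothing skops-specific; the loader and cache below are made up for the example):

from unittest import mock


def expensive_load(path):
    # stand-in for the real loader; only the call count matters here
    return object()


loader = mock.Mock(side_effect=expensive_load)
cache = {}


def get_model(path):
    # simplified caching wrapper: load once per path, then reuse the result
    if path not in cache:
        cache[path] = loader(path)
    return cache[path]


first = get_model("model.skops")
second = get_model("model.skops")
assert first is second
assert loader.call_count == 1  # the loader ran only on the first call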
Hey! Sorry it's been a while. Now that I look back at the test I wrote, I think the line you highlighted is not testing anything; I'm not sure what I was thinking at that time. I agree with your proposal to check that everything works rather than checking the implementation details as I was trying to do, and I'm happy to update the test. I'm not sure how to get rid of all of the unrelated changes in the diff: I modified 3 of the 7 files, and the other changes were applied after I ran pre-commit manually on all files. I will merge with ...
Okay, thanks for clearing that up.
Let's see if that works. Otherwise, in the worst case, you could try opening a new PR based on the latest ...
If we go about mocking ...
Are you sure that we need ...
Ah okay! I have implemented it using ... I have also gotten rid of the files in the diff that weren't supposed to be there by reverting the changes to those files.
Well done with reverting the changes, this is now much easier to review, thanks.
There isn't much work left. Regarding the test, I think it can be improved a bit, please take a look at my suggestion. Other than that, please add an entry to docs/changes.rst. Then this should be good to go.
The failing CI job is unrelated to this PR, so please ignore it.
skops/card/tests/test_card.py
Outdated
# _load_model get called
card = Card(iris_skops_file, metadata=metadata_from_config(destination_path))
with mock.patch("skops.card._model_card._load_model") as mock_load_model:
    model1 = card.get_model()
    model2 = card.get_model()
    assert model1 is model2
    # model is cached, hence _load_model is not called
    mock_load_model.assert_not_called()

    # update card with new model
    new_model = LogisticRegression()
    _, save_file = save_model_to_file(new_model, ".skops")
    del card.model
    card.model = save_file
    model3 = card.get_model()  # model gets cached
    model4 = card.get_model()
    assert model3 is model4
    assert mock_load_model.call_count == 1
I see the intent with this test, but I think it's problematic that del card.model and card.model = save_file are being used. As a skops user, I wouldn't do that and I would still expect the cached model loading to work correctly. Therefore, I made some changes to the test so that these lines are not needed:
-# _load_model get called
-card = Card(iris_skops_file, metadata=metadata_from_config(destination_path))
-with mock.patch("skops.card._model_card._load_model") as mock_load_model:
-    model1 = card.get_model()
-    model2 = card.get_model()
-    assert model1 is model2
-    # model is cached, hence _load_model is not called
-    mock_load_model.assert_not_called()
-    # update card with new model
-    new_model = LogisticRegression()
-    _, save_file = save_model_to_file(new_model, ".skops")
-    del card.model
-    card.model = save_file
-    model3 = card.get_model()  # model gets cached
-    model4 = card.get_model()
-    assert model3 is model4
-    assert mock_load_model.call_count == 1
+new_model = LogisticRegression(random_state=4321)
+# mock _load_model, it still loads the model but we can track call count
+mock_load_model = mock.Mock(side_effect=load)
+card = Card(iris_skops_file, metadata=metadata_from_config(destination_path))
+with mock.patch("skops.card._model_card._load_model", mock_load_model):
+    model1 = card.get_model()
+    model2 = card.get_model()
+    assert model1 is model2
+    # model is cached, hence _load_model is not called
+    mock_load_model.assert_not_called()
+    # override model with new model
+    dump(new_model, card.model)
+    model3 = card.get_model()
+    assert mock_load_model.call_count == 1
+    assert model3.random_state == 4321
+    model4 = card.get_model()
+    assert model3 is model4
+    assert mock_load_model.call_count == 1  # cached call
(line 3: load needs to be imported from skops.io)
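Concretely, the test module would then need something like the following import (assuming dump in the snippet above is also taken from skops.io):

from skops.io import dump, load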
This test is similar to yours but is closer to how a user would actually use the model card. Please take a look and see if you agree with me. It would also be good to have a comment at the start of the test to explain what is being tested here.
I can see how a user would go with your approach first, rather than mine. I definitely agree with your suggestion.
I tested your suggestion and it passes as expected. I also added a short comment at the beginning of the function to describe what it tests.
Thx. This LGTM. @adrinjalali not sure if you want to review too, if not feel free to merge.
just a nit, otherwise LGTM.
skops/card/_model_card.py
Outdated
@@ -554,19 +557,31 @@ def _populate_template(self, model_diagram: bool | Literal["auto"] | str):

    def get_model(self) -> Any:
        """Returns sklearn estimator object.

please put back this line
skops/card/_model_card.py
Outdated
        If the ``model`` is already loaded, return it as is. If the ``model``
        attribute is a ``Path``/``str``, load the model and return it.

this line too
Reverted them! Sorry about that, I will pay attention to that next time.
Implementation of cache model loading discussed in issue #243

This PR includes the following changes:
- _load_model function to implement model caching
- hash_model that is used as a decorator on the _load_model function
- lru_cache decorator to implement caching on top of the hash_model and _load_model functions
- test_load_model to test for cache model loading
- test_hash_model test implemented to test the hash_model function
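For context, here is a minimal sketch of one way the hash_model + lru_cache combination described above could be wired together. The names mirror the description, but the signatures and the composition order are illustrative assumptions, not the actual skops implementation (the review discussion above ultimately moved to the cached_property approach instead):

from functools import lru_cache, wraps
from hashlib import sha256
from pathlib import Path


def hash_model(func):
    # pass a content hash along with the file path so the cache key changes
    # whenever the file's content changes
    @wraps(func)
    def wrapper(path):
        digest = sha256(Path(path).read_bytes()).hexdigest()
        return func(str(path), digest)

    return wrapper


@hash_model
@lru_cache(maxsize=8)
def _load_model(path, file_hash):
    # file_hash is unused in the body; it only serves as part of the cache key
    print(f"loading {path}")
    return Path(path).read_bytes()  # stand-in for the real deserialization


model_file = Path("model.skops")
model_file.write_bytes(b"v1")
_load_model(model_file)  # loads and prints
_load_model(model_file)  # same content hash: served from the lru_cache
model_file.write_bytes(b"v2")
_load_model(model_file)  # content changed, so the model is loaded again

In this arrangement the file hash is part of the cache key, so editing the file naturally invalidates the cached entry without any explicit bookkeeping.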