
Cache model loading in model card #299

Merged
merged 40 commits into skops-dev:main on May 9, 2023

Conversation

jucamohedano
Contributor

Implementation of cache model loading discussed in issue #243

This PR includes the following changes:

  • Add a new argument to the _load_model function to enable model caching
  • Add a new function, hash_model, used as a decorator on _load_model
  • Use the lru_cache decorator to implement caching on top of the hash_model and _load_model functions
  • Extend test_load_model to cover cached model loading
  • Add test_hash_model to test the hash_model function
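As a rough illustration of the decorator approach described above (the name hash_model comes from this PR, but the body below is a hypothetical sketch, not the actual implementation):

```python
import hashlib
import os
import tempfile
from functools import lru_cache, wraps

def _file_hash(path):
    # hash the file contents in chunks to keep memory use low
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(2 ** 20), b""):
            h.update(chunk)
    return h.hexdigest()

def hash_model(func):
    # inject the file's content hash as an extra cache key, so the
    # lru_cache below is invalidated whenever the file changes
    @wraps(func)
    def wrapper(path):
        return _cached_call(func, path, _file_hash(path))
    return wrapper

@lru_cache(maxsize=None)
def _cached_call(func, path, content_hash):
    return func(path)

calls = []

@hash_model
def load_model(path):  # stand-in for skops' _load_model
    calls.append(path)
    with open(path) as f:
        return f.read()

with tempfile.TemporaryDirectory() as d:
    p = os.path.join(d, "model.txt")
    with open(p, "w") as f:
        f.write("v1")
    load_model(p)
    load_model(p)
    assert len(calls) == 1  # second call served from the cache
    with open(p, "w") as f:
        f.write("v2")
    assert load_model(p) == "v2"
    assert len(calls) == 2  # hash changed, so the loader ran again
```

The key design point is that the hash becomes part of the cache key, so a changed file is a cache miss without any explicit invalidation.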

@jucamohedano jucamohedano changed the title Cache model loading Cache model loading in model card Feb 12, 2023
@BenjaminBossan
Collaborator

Thank you for taking this issue. I'm unsure about the implementation here, which I admit is not a trivial matter. If I understand correctly, you're using an lru_cache on _load_model, and hash_model inserts an extra argument into the function's args that is actually the hash of the model file.

To me, this seems to be a bit "hacky" and I would like to suggest a different approach. In my suggestion, we would leave _load_model as is and instead do the caching on the Card object. I think it fits better there conceptually -- the cached model belongs to a specific card, whereas the _load_model function could be used in different contexts (e.g. imagine someone working on two model cards in the same program).

I did a quick-and-dirty implementation of what it could look like (in _model_card.py):

...
from hashlib import sha256
from functools import cached_property
...

class Card:
    def __init__(...):
        ...
        self._model_hash = ""

        self._populate_template()

    def get_model(self) -> Any:
        """..."""
        if isinstance(self.model, (str, Path)) and hasattr(self, "_model"):
            hash_obj = sha256()
            buf_size = 2 ** 20  # load in chunks to save memory
            with open(self.model, "rb") as f:
                for chunk in iter(lambda: f.read(buf_size), b""):
                    hash_obj.update(chunk)
            model_hash = hash_obj.hexdigest()

            # if hash changed, invalidate cache by deleting attribute
            if model_hash != self._model_hash:
                del self._model
                self._model_hash = model_hash

        return self._model

    @cached_property
    def _model(self):
        model = _load_model(self.model, self.trusted)
        return model

What do you think about that? Of course, it would require some comments and tests, but I hope you get the general idea.
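For context, the invalidation via del self._model works because functools.cached_property stores the computed value in the instance __dict__; deleting that attribute forces a recomputation on the next access. A standalone sketch with a toy class (not skops code):

```python
from functools import cached_property

class Loader:
    def __init__(self):
        self.n_loads = 0

    @cached_property
    def model(self):
        # runs only when no cached value is stored on the instance
        self.n_loads += 1
        return object()

loader = Loader()
first = loader.model
assert loader.model is first and loader.n_loads == 1  # cached
del loader.model  # invalidate: removes the value from loader.__dict__
assert loader.model is not first and loader.n_loads == 2
```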

> However, there's a small problem. I worked on a previous PR (#207) which is supposed to be merged, but I worked on this PR in the main branch of my fork (my mistake). Therefore, I had to branch off from main to implement cache model loading, and now all my previous commits are attached to #299. I'm sorry if that's an issue.

It's okay, the diff is still shown correctly, right? However, please make sure to correct this for the next PR.

@jucamohedano
Contributor Author

I understand your approach and agree that it fits better conceptually. I adopted your implementation because it works out of the box, with no errors. I wrote a test for it and would appreciate some feedback. Thanks a lot for the help on this PR!

@BenjaminBossan
Collaborator

Thanks for the updates. I haven't done a proper review yet, but I saw that some changes were unrelated to the additions of this PR. Could you please clean those up? Maybe those were changed by your IDE automatically?

Also, it seems that there are black errors, could you please set up the pre-commit hooks as described here?

Finally, the docs are not building. I think it's the same issue as in #207, so whatever fixes that should work here too.

@jucamohedano
Contributor Author

Thank you for your comments. I'm in the process of fixing the docs error; I have asked a question about it in #207.

@adrinjalali
Member

There's a merge conflict here, and the CI hasn't run completely somehow, could you please merge with upstream/main and push again?

@BenjaminBossan (Collaborator) left a comment

Thanks for updating the PR. It still shows a lot of unrelated changes in the diff, could you please remove them?

Regarding the test, I have to admit I don't quite understand it. For example, what does this test?

assert str(card._model_hash) == card.__dict__["_model_hash"]

I think what I would like to see is a more high level test, i.e. nothing that involves any hashes, since those are implementation details. One way would be to mock _load_model and assert that, when card.get_model() is called, _load_model is only called once, the first time, and after that it's not called anymore. Then, only when the underlying model is overwritten, should it be called again. WDYT?
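The call-count idea can be sketched with a toy card object (hypothetical names; the real test would patch skops' _load_model instead of injecting a loader):

```python
from unittest import mock

# stand-in loader injected into a minimal card-like class
loader = mock.Mock(return_value={"weights": [1, 2, 3]})

class ToyCard:
    def __init__(self, load):
        self._load = load
        self._model = None

    def get_model(self):
        # load lazily, then serve the cached object
        if self._model is None:
            self._model = self._load()
        return self._model

card = ToyCard(loader)
model1 = card.get_model()
model2 = card.get_model()
assert model1 is model2
assert loader.call_count == 1  # only the first call hit the loader
```

Asserting on call_count keeps the test at the behavior level and independent of how the cache key is computed.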

@jucamohedano
Contributor Author

Hey! Sorry it's been a while.

Now that I look back at the test I wrote I think that the line you highlighted is not testing anything, I'm not sure what I was thinking at that time.

I agree with your proposal to check that everything works rather than checking the details as I was trying to do. I'm happy to update the test.

I'm not sure how to get rid of all of the unrelated changes in the diff. I modified 3 of the 7 files; the other changes were applied after I ran pre-commit manually on all files. I will merge with upstream/main again and see if I can get rid of them.

@BenjaminBossan
Collaborator

> Now that I look back at the test I wrote I think that the line you highlighted is not testing anything, I'm not sure what I was thinking at that time.
>
> I agree with your proposal to check that everything works rather than checking the details as I was trying to do. I'm happy to update the test.

Okay, thanks for clearing that up.

> I'm not sure how to get rid of all of the unrelated changes in the diff. I have modified 3 files of the 7, the other changes were applied after I ran pre-commit manually on all files. I will merge with upstream/main again and see if I can get rid of them.

Let's see if that works. Otherwise, in the worst case, you could try opening a new PR based on the latest main with the same changes, that should hopefully work. For me, it's not quite clear if merging the PR would also merge those unrelated diffs (which we want to avoid) or if they're just displayed on GitHub but merging will not actually change those lines.

@jucamohedano
Contributor Author

If we go with mocking the _load_model function, we have to install the pytest-mock plugin. Would that be okay, @BenjaminBossan? I have written the test with pytest-mock in case that's fine.

@BenjaminBossan
Collaborator

> If we go about mocking _load_model function we have to install the pytest-mock plugin. Would that be okay @BenjaminBossan ? I have written the test with pytest-mock in case that's fine

Are you sure that we need pytest-mock? Personally, I never needed that package, as unittest.mock from the standard library and pytest's monkeypatch were more than enough for my needs. If you search for "mock" and "patch" in skops, you'll find a couple of examples.
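For example, patching a module attribute with unittest.mock from the standard library takes a few lines and needs no plugin:

```python
import json
from unittest import mock

with mock.patch("json.loads", return_value={"patched": True}) as m:
    assert json.loads("{}") == {"patched": True}  # the mock answers

m.assert_called_once_with("{}")
assert json.loads("{}") == {}  # the real function is restored on exit
```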

@jucamohedano
Contributor Author

Ah, okay! I have implemented it using unittest.mock, and I think that's sufficient. Let me know what you think.

I have also gotten rid of the files in the diff that weren't supposed to be there by reverting the changes to those files.

@BenjaminBossan (Collaborator) left a comment

Well done with reverting the changes, this is now much easier to review, thanks.

There isn't much work left. Regarding the test, I think it can be improved a bit, please take a look at my suggestion. Other than that, please add an entry to docs/changes.rst. Then this should be good to go.

The failing CI job is unrelated to this PR, so please ignore it.

Comment on lines 152 to 168
# _load_model gets called
card = Card(iris_skops_file, metadata=metadata_from_config(destination_path))
with mock.patch("skops.card._model_card._load_model") as mock_load_model:
    model1 = card.get_model()
    model2 = card.get_model()
    assert model1 is model2
    # model is cached, hence _load_model is not called
    mock_load_model.assert_not_called()
    # update card with new model
    new_model = LogisticRegression()
    _, save_file = save_model_to_file(new_model, ".skops")
    del card.model
    card.model = save_file
    model3 = card.get_model()  # model gets cached
    model4 = card.get_model()
    assert model3 is model4
    assert mock_load_model.call_count == 1
Collaborator

I see the intent with this test, but I think it's problematic that del card.model and card.model = save_file are being used. As a skops user, I wouldn't do that and I would still expect the cached model loading to work correctly. Therefore, I made some changes to the test so that these lines are not needed:

Suggested change (replacing the test above):

new_model = LogisticRegression(random_state=4321)
# mock _load_model, it still loads the model but we can track call count
mock_load_model = mock.Mock(side_effect=load)
card = Card(iris_skops_file, metadata=metadata_from_config(destination_path))
with mock.patch("skops.card._model_card._load_model", mock_load_model):
    model1 = card.get_model()
    model2 = card.get_model()
    assert model1 is model2
    # model is cached, hence _load_model is not called
    mock_load_model.assert_not_called()
    # override model with new model
    dump(new_model, card.model)
    model3 = card.get_model()
    assert mock_load_model.call_count == 1
    assert model3.random_state == 4321
    model4 = card.get_model()
    assert model3 is model4
    assert mock_load_model.call_count == 1  # cached call

(line 3: load needs to be imported from skops.io)

This test is similar to yours but is closer to how a user would actually use the model card. Please take a look and see if you agree with me. It would also be good to have a comment at the start of the test to explain what is being tested here.
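The Mock(side_effect=...) trick used above is what lets the test keep the real loading behavior while still counting calls; a minimal illustration with a plain function:

```python
from unittest import mock

def add(a, b):
    return a + b

tracked = mock.Mock(side_effect=add)  # delegates to add, records calls
assert tracked(2, 3) == 5
assert tracked.call_count == 1
tracked(1, 1)
assert tracked.call_count == 2
tracked.assert_called_with(1, 1)
```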

Contributor Author

I can see how a user would go with your approach first, rather than mine. I definitely agree with your suggestion.

Contributor Author

I tested your suggestion and it passes as expected. I also added a short comment at the beginning of the function describing what it tests.

@BenjaminBossan (Collaborator) left a comment
The reason will be displayed to describe this comment to others. Learn more.

Thx. This LGTM. @adrinjalali not sure if you want to review too; if not, feel free to merge.

@adrinjalali (Member) left a comment

just a nit, otherwise LGTM.

@@ -554,19 +557,31 @@ def _populate_template(self, model_diagram: bool | Literal["auto"] | str):

def get_model(self) -> Any:
"""Returns sklearn estimator object.

Member

please put back this line

If the ``model`` is already loaded, return it as is. If the ``model``
attribute is a ``Path``/``str``, load the model and return it.

Member

this line too

Contributor Author

Reverted them! Sorry about that; I will pay more attention to it next time.

@adrinjalali adrinjalali merged commit 8a58101 into skops-dev:main May 9, 2023