Skip to content

Commit

Permalink
[CLIP]: Zeroshot Pipeline (#1098)
Browse files Browse the repository at this point in the history
* initial refactor

* move BasePipeline to a new file

* test fix

* anothe test fix

* fix import

* revert

* initial refactor

* add tests for BasePipeline

* move BasePipeline to a new file

* initial refactor

* update test; finish off initial refactoring changes post local testing

* initial commit for clip zero-shot

* add basic structure for text branch and zeroshot

* add schema details

* update pipelines after running mock engine tests

* add zeroshot tests

* rebase fix

* clean-up comments; add note about onnx export issue

* add clip dependency

* move paths to fixtures

* rebase fix

* rebase fix

* refactor pipelines to separate visual, text, and zeroshot. also add pytest skips until model issues are resolved

* make zershot arguments explicit; deal with quality
:

* update workflow to install clip for base test

* update pipelines after using MLR's zeroshot models

* add readme with examples, update setup.py and clean-up return types

* quality fix

* Update visual_pipeline.py

update model loading

Co-authored-by: dbogunowicz <[email protected]>

* add docstring

* move docstring; add params

* fix rebase

* quality
  • Loading branch information
dsikka authored Aug 2, 2023
1 parent 3254ca8 commit ed9b1ee
Show file tree
Hide file tree
Showing 9 changed files with 562 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
- name: "Clean sparsezoo directory"
run: rm -r sparsezoo/
- name: ⚙️ Install dependencies
run: pip3 install .[dev,server,image_classification,transformers] opencv-python
run: pip3 install .[dev,server,image_classification,transformers,clip] opencv-python
- name: Run base tests
run: make test
cli-smoke-tests:
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _parse_requirements_file(file_path):
"haystack_reqs.txt",
)
_haystack_integration_deps = _parse_requirements_file(_haystack_requirements_file_path)

_clip_deps = ["open_clip_torch==2.20.0", "scipy==1.10.1"]

_torch_deps = ["torch>=1.7.0,<=2.0"]

Expand Down Expand Up @@ -280,6 +280,7 @@ def _setup_extras() -> Dict:
"yolov8": _yolov8_integration_deps,
"transformers": _transformers_integration_deps,
"torch": _torch_deps,
"clip": _clip_deps,
}


Expand Down
75 changes: 75 additions & 0 deletions src/deepsparse/clip/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# CLIP Inference Pipelines

DeepSparse allows inference on [CLIP](https://github.com/mlfoundations/open_clip) models.

The CLIP integration currently supports the following task:
- **Zero-shot Image Classification** - Classifying images given possible classes

## Getting Started

Before you start your adventure with the DeepSparse Engine, make sure that your machine is compatible with our [hardware requirements](https://docs.neuralmagic.com/deepsparse/source/hardware.html).

### Installation
```pip install deepsparse[clip]```

### Model Format
By default, to deploy CLIP models using the DeepSparse Engine, it is required to supply the model in the ONNX format. This grants the engine the flexibility to serve any model in a framework-agnostic environment. To see examples of pulling CLIP models and exporting them to ONNX, please see the [sparseml documentation](https://github.com/neuralmagic/sparseml/tree/main/integrations/clip). For the Zero-shot image classification workflow, two ONNX models are required, a visual model for CLIP's visual branch, and a text model for CLIP's text branch. Both of these model should be produced through the sparseml integration linked above.

### Deployment examples:
The following example uses pipelines to run the CLIP models for inference. As input, the pipeline ingests a list of images and a list of possible classes. A class is returned for each of the provided images.

If you don't have images ready, pull down the sample images using the following commands:

```bash
wget -O basilica.jpg https://raw.githubusercontent.com/neuralmagic/deepsparse/main/src/deepsparse/yolo/sample_images/basilica.jpg

wget -O buddy.jpeg https://raw.githubusercontent.com/neuralmagic/deepsparse/main/tests/deepsparse/pipelines/sample_images/buddy.jpeg
```

This will pull down two images, one with a happy dog and one with St.Peter's basilica.

#### Zero-shot Prediction

Let's run an example to clasify the images. We'll provide the images in a list with their file names as well as a list of possible classes. We'll also provide paths to the exported ONNX models.

```python
import numpy as np

from deepsparse import BasePipeline
from deepsparse.clip import (
CLIPTextInput,
CLIPVisualInput,
CLIPZeroShotInput
)

possible_classes = ["ice cream", "an elephant", "a dog", "a building", "a church"]
images = ["basilica.jpg", "buddy.jpeg"]

model_path_text = "zeroshot_research/text/model.onnx"
model_path_visual = "zeroshot_research/visual/model.onnx"

kwargs = {
"visual_model_path": model_path_visual,
"text_model_path": model_path_text,
}
pipeline = BasePipeline.create(task="clip_zeroshot", **kwargs)

pipeline_input = CLIPZeroShotInput(
image=CLIPVisualInput(images=images),
text=CLIPTextInput(text=possible_classes),
)

output = pipeline(pipeline_input).text_scores
for i in range(len(output)):
prediction = possible_classes[np.argmax(output[i])]
print(f"Image {images[i]} is a picture of {prediction}")
```

Running the code above, we get the following outuput:

```
DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.6.0.20230727 COMMUNITY | (3cb4a3e5) (optimized) (system=avx2, binary=avx2)
Image basilica.jpg is a picture of a church
Image buddy.jpeg is a picture of a dog
```
31 changes: 31 additions & 0 deletions src/deepsparse/clip/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from deepsparse.clip.text_pipeline import (
CLIPTextInput,
CLIPTextOutput,
CLIPTextPipeline,
)
from deepsparse.clip.visual_pipeline import (
CLIPVisualInput,
CLIPVisualOutput,
CLIPVisualPipeline,
)
from deepsparse.clip.zeroshot_pipeline import (
CLIPZeroShotInput,
CLIPZeroShotOutput,
CLIPZeroShotPipeline,
)
19 changes: 19 additions & 0 deletions src/deepsparse/clip/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


__all__ = ["CLIP_RGB_MEANS", "CLIP_RGB_STDS"]

CLIP_RGB_MEANS = [0.48145466, 0.4578275, 0.40821073]
CLIP_RGB_STDS = [0.26862954, 0.26130258, 0.27577711]
102 changes: 102 additions & 0 deletions src/deepsparse/clip/text_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Type, Union

import numpy as np
from pydantic import BaseModel, Field

from deepsparse.pipeline import Pipeline
from deepsparse.utils import model_to_path
from open_clip.tokenizer import tokenize


__all__ = ["CLIPTextInput", "CLIPTextOutput", "CLIPTextPipeline"]


class CLIPTextInput(BaseModel):
"""
Input for the CLIP Text Branch
"""

text: Union[str, List[str]] = Field(description="List of text to process")


class CLIPTextOutput(BaseModel):
"""
Output for the CLIP Text Branch
"""

text_embeddings: List[Any] = Field(
description="Text embeddings for the single text or list of embeddings for "
"multiple."
)


@Pipeline.register(task="clip_text", default_model_path=None)
class CLIPTextPipeline(Pipeline):
def __init__(self, **kwargs):
super().__init__(**kwargs)

self.tokenizer = tokenize

@property
def input_schema(self) -> Type[CLIPTextInput]:
"""
:return: pydantic model class that inputs to this pipeline must comply to
"""
return CLIPTextInput

@property
def output_schema(self) -> Type[CLIPTextOutput]:
"""
:return: pydantic model class that inputs to this pipeline must comply to
"""
return CLIPTextOutput

def setup_onnx_file_path(self):
"""
Performs any setup to unwrap and process the given `model_path` and other
class properties into an inference ready onnx file to be compiled by the
engine of the pipeline
:return: file path to the ONNX file for the engine to compile
"""
return model_to_path(self.model_path)

def process_inputs(self, inputs: CLIPTextInput) -> List[np.ndarray]:
"""
Preprocess inputs for CLIP's Trext Branch to comply with the DeepSparse Engine
:param inputs: CLITextInput
:return: list of preprocessed numpy arrays
"""
if isinstance(inputs.text, str):
inputs.text = [inputs.text]

tokens = self.tokenizer(inputs.text)
tokens = [np.array(t).astype(np.int32) for t in tokens]
tokens = np.stack(tokens, axis=0)
return [tokens]

def process_engine_outputs(
self, engine_outputs: List[np.array], **kwargs
) -> CLIPTextOutput:
"""
:param engine_outputs: list of numpy arrays that are the output of the engine
forward pass
:return: outputs of engine post-processed into an object in the `output_schema`
format of this pipeline
"""
return self.output_schema(text_embeddings=engine_outputs)
Loading

0 comments on commit ed9b1ee

Please sign in to comment.