feat: Add custom embedder #2236
self,
inputCol=None,
outputCol=None,
useTRTFlag=None,
nit: replace useTRTFlag with a runtime parameter that accepts "cpu", "gpu", or "tensorrt" and defaults to "cpu".
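One way to read this suggestion, as a minimal sketch: a single string-valued runtime option replaces the boolean flag and is validated up front. The names `SUPPORTED_RUNTIMES` and `validate_runtime` are hypothetical and do not come from the PR; how the value would be wired into the transformer's Param is left out.

```python
# Hypothetical sketch: these names are illustrative, not taken from the PR.
SUPPORTED_RUNTIMES = ("cpu", "gpu", "tensorrt")


def validate_runtime(runtime=None):
    """Return a normalized runtime string, falling back to "cpu" when unset."""
    if runtime is None:
        return "cpu"
    normalized = str(runtime).strip().lower()
    if normalized not in SUPPORTED_RUNTIMES:
        raise ValueError(
            f"runtime must be one of {SUPPORTED_RUNTIMES}, got {runtime!r}"
        )
    return normalized
```

A setter on the transformer could call a helper like this before storing the value, so an unsupported runtime fails early rather than at inference time.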

# Define additional parameters
useTRT = Param(Params._dummy(), "useTRT", "True if use TRT acceleration")
driverOnly = Param(
nit: remove the driverOnly code
inputCol="combined",
outputCol="embeddings",
Look at other examples in the library for the proper defaults for these columns.
for batch_size in [64, 32, 16, 8, 4, 2, 1]:
    for sentence_length in [20, 300, 512]:
        yield (batch_size, sentence_length)
Turn these magic numbers into parameters with defaults.
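A minimal sketch of what this could look like, assuming the loop lives in a generator function. The function name `iter_shapes` and the two constant names are hypothetical; the default values are the ones hard-coded in the snippet under review.

```python
from typing import Iterator, Sequence, Tuple

# Defaults mirror the literals in the reviewed snippet; the names are hypothetical.
DEFAULT_BATCH_SIZES = (64, 32, 16, 8, 4, 2, 1)
DEFAULT_SENTENCE_LENGTHS = (20, 300, 512)


def iter_shapes(
    batch_sizes: Sequence[int] = DEFAULT_BATCH_SIZES,
    sentence_lengths: Sequence[int] = DEFAULT_SENTENCE_LENGTHS,
) -> Iterator[Tuple[int, int]]:
    """Yield every (batch_size, sentence_length) combination, batch size first."""
    for batch_size in batch_sizes:
        for sentence_length in sentence_lengths:
            yield (batch_size, sentence_length)
```

Callers that need a smaller sweep can then pass their own sequences, for example `iter_shapes(batch_sizes=[8, 1])`, without editing the function body.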
"""
Create a data loader with synthetic data using Faker.
"""
faker = Faker()
nit: let's try to remove this dependency
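As a rough sketch of one way to drop the dependency, synthetic sentences can be built from the standard library alone. The word pool, the function name `make_sentences`, and the reading of sentence_length as a word count are all assumptions for illustration; the PR may measure length differently.

```python
import random

# Hypothetical word pool; any small vocabulary works for synthetic input.
_WORDS = ("spark", "model", "vector", "query", "text", "batch", "token", "data")


def make_sentences(batch_size, sentence_length, seed=0):
    """Build batch_size sentences of sentence_length words each.

    Assumption: sentence_length counts words, not characters.
    A fixed seed keeps the synthetic data reproducible across runs.
    """
    rng = random.Random(seed)
    return [
        " ".join(rng.choice(_WORDS) for _ in range(sentence_length))
        for _ in range(batch_size)
    ]
```

Because the generator is seeded, repeated runs produce the same inputs, which a Faker-based loader does not guarantee unless it is seeded as well.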
for sentence_length in [20, 300, 512]:
    yield (batch_size, sentence_length)

def get_dataloader(repeat_times: int = 2):
nit: rename to _get_dataloader
func, dataloader=tqdm(get_dataloader(), total=total_batches), config=conf
)

def run_on_driver(self, queries, spark):
Likewise, prefix this with an underscore: _run_on_driver
"""
return self._defaultCopy(extra)

def load_data_food_reviews(self, spark, path=None, limit=1000):
Move this code into the demo.
class SuppressLogging:
    def __init__(self):
        self._original_stderr = None

    def start(self):
        """Start suppressing logging by redirecting sys.stderr to /dev/null."""
        if self._original_stderr is None:
            self._original_stderr = sys.stderr
            sys.stderr = open('/dev/null', 'w')

    def stop(self):
        """Stop suppressing logging and restore sys.stderr."""
        if self._original_stderr is not None:
            sys.stderr.close()
            sys.stderr = self._original_stderr
            self._original_stderr = None
remove
FloatType,
)

class EmbeddingTransformer(Transformer, HasInputCol, HasOutputCol):
nit: rename the class to HuggingFaceSentenceEmbedder.
Also name the file HuggingFaceSentenceEmbedder.py.
modelName="intfloat/e5-large-v2",
moduleName="e5-large-v2",
nit: no defaults here, and try to make the moduleName parameter go away
Initialize the EmbeddingTransformer with input/output columns and optional TRT flag.
"""
super(EmbeddingTransformer, self).__init__()
self._setDefault(
Try it on some other models from: https://sbert.net/docs/sentence_transformer/pretrained_models.html
tools/init_scripts/init_retriever.sh
Outdated
/databricks/python/bin/pip install --extra-index-url https://pypi.nvidia.com cudf-cu11~=${RAPIDS_VERSION} cuml-cu11~=${RAPIDS_VERSION} pylibraft-cu11~=${RAPIDS_VERSION} rmm-cu11~=${RAPIDS_VERSION}

# install model navigator
/databricks/python/bin/pip install --extra-index-url https://pypi.nvidia.com onnxruntime-gpu==1.16.3 "tensorrt==9.3.0.post12.dev1" "triton-model-navigator<1" "sentence_transformers~=2.2.2" "faker" "urllib3<2"
nit: remove faker
/azp run
Azure Pipelines successfully started running 1 pipeline(s).
Related Issues/PRs
#xxx
What changes are proposed in this pull request?
Briefly describe the changes included in this Pull Request.
How is this patch tested?
Does this PR change any dependencies?
Does this PR add a new feature? If so, have you added samples on website?
Locate the relevant file in the website/docs/documentation folder. Make sure you choose the correct class (estimators/transformers) and namespace.
Make sure the DocTable points to the correct API link.
Run yarn run start to make sure the website renders correctly.
Add <!--pytest-codeblocks:cont--> before each Python code block to enable auto-tests for Python samples.
Make sure the WebsiteSamplesTests job passes in the pipeline.