Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Colin Ho authored and committed on Oct 23, 2024
1 parent 7025944 commit 8999d1d
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 18 deletions.
1 change: 0 additions & 1 deletion daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,7 +1208,6 @@ def minhash(
assert isinstance(num_hashes, int)
assert isinstance(ngram_size, int)
assert isinstance(seed, int)
- assert isinstance(hash_function, str)
assert isinstance(hash_function, native.HashFunctionKind), f"Hash function {hash_function} not found"

return Expression._from_pyexpr(native.minhash(self._expr, num_hashes, ngram_size, seed, hash_function))
Expand Down
4 changes: 1 addition & 3 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,12 +585,10 @@ def minhash(
raise ValueError(f"expected an integer for ngram_size but got {type(ngram_size)}")
if seed is not None and not isinstance(seed, int):
raise ValueError(f"expected an integer or None for seed but got {type(seed)}")
- if not isinstance(hash_function, str):
-     raise ValueError(f"expected a string for hash_function but got {type(hash_function)}")
  if not isinstance(hash_function, HashFunctionKind):
      raise ValueError(f"expected HashFunctionKind for hash_function but got {type(hash_function)}")

- return Series._from_pyseries(self._series.minhash(num_hashes, ngram_size, seed))
+ return Series._from_pyseries(self._series.minhash(num_hashes, ngram_size, seed, hash_function))

def _to_str_values(self) -> Series:
return Series._from_pyseries(self._series.to_str_values())
Expand Down
17 changes: 3 additions & 14 deletions tests/series/test_minhash.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,9 @@
from __future__ import annotations

- from enum import Enum

import pytest

from daft import DataType, Series


- class HashFunctionKind(Enum):
-     """
-     Kind of hash function to use for minhash.
-     """
-
-     MurmurHash3 = 0
-     XxHash = 1
-     Sha1 = 2
+ from daft.daft import HashFunctionKind


def minhash_none(
Expand All @@ -25,9 +14,9 @@ def minhash_none(
hash_function: HashFunctionKind,
) -> list[list[int] | None]:
  if seed is None:
-     return series.minhash(num_hashes, ngram_size, hash_function=hash_function.name.lower()).to_pylist()
+     return series.minhash(num_hashes, ngram_size, hash_function=hash_function).to_pylist()
  else:
-     return series.minhash(num_hashes, ngram_size, seed, hash_function=hash_function.name.lower()).to_pylist()
+     return series.minhash(num_hashes, ngram_size, seed, hash_function=hash_function).to_pylist()


test_series = Series.from_pylist(
Expand Down

0 comments on commit 8999d1d

Please sign in to comment.