Skip to content

Commit

Permalink
Anchor text (#15)
Browse files Browse the repository at this point in the history
* anchor text

* clean up notebook

* anchor text

* clean up notebook

* add anchor text

* update example notebook

* add build docs

* output example notebook

* update example

* change example title

* add logging and adjust UNK data type
  • Loading branch information
arnaudvl authored Mar 12, 2019
1 parent ca4bef0 commit 349ed2b
Show file tree
Hide file tree
Showing 20 changed files with 939 additions and 28 deletions.
29 changes: 29 additions & 0 deletions alibi/datasets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,36 @@
from io import BytesIO
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
import tarfile
from typing import Tuple
from urllib.request import urlopen


def movie_sentiment(
        url: str = 'http://www.cs.cornell.edu/People/pabo/movie-review-data/rt-polaritydata.tar.gz'
) -> Tuple[list, list]:
    """
    The movie review dataset, equally split between negative and positive reviews.

    Parameters
    ----------
    url
        Location of the gzipped tar archive holding the reviews. Defaults to the
        original Cornell download link; any URL understood by `urlopen` works,
        e.g. a mirror or a local `file://` copy.

    Returns
    -------
    Movie reviews and sentiment labels (0 means 'negative' and 1 means 'positive').
    Each label is the position of the archive member the review was read from
    (after skipping the first archive entry), so the 0/1 meaning relies on the
    member order inside the archive.
    """
    # Download the whole archive into memory; the context manager makes sure the
    # connection is released even if the read fails.
    with urlopen(url) as resp:
        archive = BytesIO(resp.read())

    data = []
    labels = []
    # `with` closes the tar handle on every exit path, including exceptions.
    with tarfile.open(fileobj=archive, mode="r:gz") as tar:
        # The first archive entry is skipped (presumably the top-level directory);
        # the remaining members are labelled by their position.
        for i, member in enumerate(tar.getnames()[1:]):
            f = tar.extractfile(member)
            if f is None:
                # extractfile returns None for members that are not regular files
                # (e.g. directories); there is nothing to read from those.
                continue
            for line in f:
                # Decode once and drop lines that are not valid UTF-8.
                try:
                    text = line.decode('utf8')
                except UnicodeDecodeError:
                    continue
                data.append(text.strip())
                labels.append(i)
    return data, labels


def adult(features_drop: list = ["fnlwgt", "Education-Num"]) -> Tuple[np.ndarray, np.ndarray, list, dict]:
Expand Down
4 changes: 3 additions & 1 deletion alibi/explainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,7 @@
"""

from .anchor.anchor_tabular import AnchorTabular
from .anchor.anchor_text import AnchorText

__all__ = ["AnchorTabular"]
__all__ = ["AnchorTabular",
"AnchorText"]
25 changes: 15 additions & 10 deletions alibi/explainers/anchor/anchor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def make_tuples(previous_best: list, state: dict) -> list:
return list(new_tuples)

@staticmethod
def get_sample_fns(sample_fn: Callable, tuples: list, state: dict) -> list:
def get_sample_fns(sample_fn: Callable, tuples: list, state: dict, data_type: str = None) -> list:
"""
Parameters
----------
Expand All @@ -306,12 +306,13 @@ def get_sample_fns(sample_fn: Callable, tuples: list, state: dict) -> list:
List of anchor candidates
state
Dictionary with the relevant metrics like coverage and samples for candidate anchors
data_type
Data type for raw data
Returns
-------
List with sample functions for each candidate anchor.
"""

def complete_sample_fn(t: tuple, n: int) -> int:
"""
Parameters
Expand Down Expand Up @@ -339,8 +340,9 @@ def complete_sample_fn(t: tuple, n: int) -> int:
prealloc_size = state['prealloc_size']
current_idx = data.shape[0]
state['data'] = np.vstack((state['data'], np.zeros((prealloc_size, data.shape[1]), data.dtype)))
dtype = data_type if data_type is not None else raw_data.dtype
state['raw_data'] = np.vstack((state['raw_data'], np.zeros((prealloc_size, raw_data.shape[1]),
raw_data.dtype)))
dtype=dtype)))
state['labels'] = np.hstack((state['labels'], np.zeros(prealloc_size, labels.dtype)))
return labels.sum()

Expand Down Expand Up @@ -417,10 +419,10 @@ def get_anchor_from_tuple(t: tuple, state: dict) -> dict:

@staticmethod
def anchor_beam(sample_fn: Callable, delta: float = 0.05, epsilon: float = 0.1, batch_size: int = 10,
desired_confidence: float = 1, beam_size: int = 1,
verbose: bool = False, epsilon_stop: float = 0.05, min_samples_start: int = 0,
max_anchor_size: int = None, verbose_every: int = 1,
stop_on_first: bool = False, coverage_samples: int = 10000) -> dict:
desired_confidence: float = 1, beam_size: int = 1, verbose: bool = False,
epsilon_stop: float = 0.05, min_samples_start: int = 0, max_anchor_size: int = None,
verbose_every: int = 1, stop_on_first: bool = False, coverage_samples: int = 10000,
data_type: str = None) -> dict:
"""
Parameters
----------
Expand All @@ -446,6 +448,8 @@ def anchor_beam(sample_fn: Callable, delta: float = 0.05, epsilon: float = 0.1,
Whether to print intermediate output every verbose_every steps
stop_on_first
coverage_samples
data_type
Data type for raw data
Returns
-------
Expand Down Expand Up @@ -489,7 +493,8 @@ def anchor_beam(sample_fn: Callable, delta: float = 0.05, epsilon: float = 0.1,
prealloc_size = batch_size * 10000
current_idx = data.shape[0]
data = np.vstack((data, np.zeros((prealloc_size, data.shape[1]), data.dtype)))
raw_data = np.vstack((raw_data, np.zeros((prealloc_size, raw_data.shape[1]), raw_data.dtype)))
dtype = data_type if data_type is not None else raw_data.dtype
raw_data = np.vstack((raw_data, np.zeros((prealloc_size, raw_data.shape[1]), dtype=dtype)))
labels = np.hstack((labels, np.zeros(prealloc_size, labels.dtype)))
n_features = data.shape[1]
state = {'t_idx': collections.defaultdict(lambda: set()),
Expand Down Expand Up @@ -531,7 +536,7 @@ def anchor_beam(sample_fn: Callable, delta: float = 0.05, epsilon: float = 0.1,
# these functions sample randomly for all features except for the ones in the candidate anchors
# for the features in the anchor it uses the same category (categorical features) or samples from ...
# ... the same bin (discretized numerical features) as the feature in the observation that is explained
sample_fns = AnchorBaseBeam.get_sample_fns(sample_fn, tuples, state)
sample_fns = AnchorBaseBeam.get_sample_fns(sample_fn, tuples, state, data_type=dtype)

# for each tuple, get initial nb of samples used and prec(A)
initial_stats = AnchorBaseBeam.get_initial_statistics(tuples, state)
Expand Down Expand Up @@ -606,7 +611,7 @@ def anchor_beam(sample_fn: Callable, delta: float = 0.05, epsilon: float = 0.1,
tuples = []
for i in range(0, current_size):
tuples.extend(best_of_size[i])
sample_fns = AnchorBaseBeam.get_sample_fns(sample_fn, tuples, state)
sample_fns = AnchorBaseBeam.get_sample_fns(sample_fn, tuples, state, data_type=dtype)
initial_stats = AnchorBaseBeam.get_initial_statistics(tuples, state)
chosen_tuples = AnchorBaseBeam.lucb(sample_fns, initial_stats, epsilon,
delta, batch_size, 1, verbose=verbose)
Expand Down
Loading

0 comments on commit 349ed2b

Please sign in to comment.