Cnn influence example #195

Merged
merged 31 commits into from
Dec 29, 2022
Changes from 18 commits

Commits
93ec7f2
WIP influence example using imagenet and resnet
Xuzzo Nov 8, 2022
7b73a98
WIP extending example of cnn with ifs
Xuzzo Nov 9, 2022
17a45e9
WIP notebook
Xuzzo Nov 11, 2022
7990df1
WIP still working on damn example
Xuzzo Nov 11, 2022
b8df585
WIP writing docs for notebook
Xuzzo Nov 13, 2022
4800233
WIP writing theory in example
Xuzzo Nov 23, 2022
9c62032
Merge branch 'develop' into cnn_influence_example
Xuzzo Dec 6, 2022
7ada9a6
add theory appendix to imagenet notebook
Xuzzo Dec 6, 2022
1f36c05
moved methods to notebook support
Xuzzo Dec 6, 2022
c6c6da8
update to imagenet notebook - docs to notebook support methods
Xuzzo Dec 7, 2022
d203ad4
cosmetic changes to notebook
Xuzzo Dec 9, 2022
7d38b89
trying to solve tox issue in ci
Xuzzo Dec 9, 2022
b70432d
fix ci data loading in notebook
Xuzzo Dec 9, 2022
b16287e
minor changes to docs
Xuzzo Dec 9, 2022
bcc745a
update changelog and fix typing in notebook_support
Xuzzo Dec 12, 2022
cb18bdf
including progress bar
Xuzzo Dec 19, 2022
f371af8
addressing MR comments
Xuzzo Dec 20, 2022
0c7f466
dummy commit to re-trigger pipelines
Xuzzo Dec 21, 2022
d156f6e
minor changes to notebooks
Xuzzo Dec 21, 2022
eb63458
add sphinx hidden to cells
Xuzzo Dec 22, 2022
461eb27
minor changes to notebook
Xuzzo Dec 22, 2022
b4793f3
Forward calls to wrapped torch model
mdbenito Dec 27, 2022
f059b6b
Return ndarrays in TorchModel.fit
mdbenito Dec 27, 2022
d193eec
Types, strings, etc.
mdbenito Dec 27, 2022
826ae08
Remove clutter from notebook, some refactoring, rephrasing and tweaking.
mdbenito Dec 27, 2022
1e9e0a3
git ignore saved models
mdbenito Dec 27, 2022
c7256fc
Merge pull request #235 from appliedAI-Initiative/fix/cnn-influence
Xuzzo Dec 28, 2022
170970d
dummy commit to trigger pipeline
Xuzzo Dec 28, 2022
a31b263
fix typing
Xuzzo Dec 28, 2022
2d3a719
minor changes to notebooks
Xuzzo Dec 28, 2022
9f704d0
add req-notebooks and remove InternalDataset
Xuzzo Dec 28, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,9 @@
- **Breaking change:** Introduces a class ValuationResult to gather and inspect
results from all valuation algorithms
[PR #214](https://github.com/appliedAI-Initiative/pyDVL/pull/214)
- Fixes bug in Influence calculation with multi-dimensional input and adds
new example notebook
[PR #195](https://github.com/appliedAI-Initiative/pyDVL/pull/195)

## 0.3.0 - 💥 Breaking changes

1,271 changes: 1,271 additions & 0 deletions notebooks/influence_imagenet.ipynb

Large diffs are not rendered by default.
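The imagenet notebook itself is too large to render in the diff. For orientation only, the computation it builds up to has roughly the following shape; the entry point, import path and keyword arguments below are assumptions about the pyDVL API of this era, not taken from this PR (a minimal sketch, not the notebook's actual code):

# Rough sketch only: function name, import path and arguments are assumed, not confirmed by this PR.
from pydvl.influence import compute_influences  # hypothetical import path

# model: a trained CNN wrapped for pyDVL; the tensors come from the preprocessed dataframes
# produced by notebook_support.load_preprocess_imagenet below.
influences = compute_influences(
    model,
    x_train, y_train,     # training points whose influence is estimated
    x_test, y_test,       # test points on which the influence is measured
    influence_type="up",  # assumed: up-weighting influence, as in Koh & Liang (2017)
)
# influences would then have one row per test point and one column per training point.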

151 changes: 111 additions & 40 deletions notebooks/influence_wine.ipynb

Large diffs are not rendered by default.

348 changes: 339 additions & 9 deletions notebooks/notebook_support.py
@@ -1,26 +1,41 @@
-from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple
+from copy import deepcopy
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from cloudpickle import pickle as pkl
from PIL.JpegImagePlugin import JpegImageFile

from pydvl.influence.model_wrappers.torch_wrappers import TorchModel
from pydvl.utils import Dataset

try:
import torch

_TORCH_INSTALLED = True
except ImportError:
_TORCH_INSTALLED = False

if TYPE_CHECKING:
from numpy.typing import NDArray

imgnet_model_data_path = Path().resolve().parent / "data/imgnet_model"


def plot_dataset(
-    train_ds: Tuple["NDArray", "NDArray"],
-    test_ds: Tuple["NDArray", "NDArray"],
-    x_min: Optional["NDArray"] = None,
-    x_max: Optional["NDArray"] = None,
+    train_ds: Tuple["NDArray[np.float_]", "NDArray[np.int_]"],
+    test_ds: Tuple["NDArray[np.float_]", "NDArray[np.int_]"],
+    x_min: Optional["NDArray[np.float_]"] = None,
+    x_max: Optional["NDArray[np.float_]"] = None,
*,
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
legend_title: Optional[str] = None,
vline: Optional[float] = None,
-    line: Optional["NDArray"] = None,
+    line: Optional["NDArray[np.float_]"] = None,
suptitle: Optional[str] = None,
s: Optional[float] = None,
figsize: Tuple[int, int] = (20, 10),
@@ -95,15 +110,15 @@ def plot_dataset(


def plot_influences(
x: "NDArray",
influences: "NDArray",
x: "NDArray[np.float_]",
influences: "NDArray[np.float_]",
corrupted_indices: Optional[List[int]] = None,
*,
ax: Optional[plt.Axes] = None,
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
legend_title: Optional[str] = None,
-    line: Optional["NDArray"] = None,
+    line: Optional["NDArray[np.float_]"] = None,
suptitle: Optional[str] = None,
colorbar_limits: Optional[Tuple] = None,
) -> plt.Axes:
@@ -258,3 +273,318 @@ def _handle_legend(scatter):
edgecolors="r",
s=80,
)


def load_preprocess_imagenet(
train_size: float,
test_size: float,
downsample_ds_to_fraction: float = 1,
keep_labels: Optional[List] = None,
random_state: Optional[int] = None,
is_CI: bool = False,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""Loads the tiny imagened dataset from huggingface and preprocesses it
for model input.

    :param train_size: fraction of the non-test data to use for training
    :param test_size: fraction of the data to hold out for testing
    :param downsample_ds_to_fraction: which fraction of the full dataset to keep. \
        E.g. with downsample_ds_to_fraction=0.2, only 20% of the dataset is kept.
    :param keep_labels: which of the original labels to keep. \
        E.g. keep_labels=[10,20] only returns the images with labels 10 and 20.
    :param random_state: Random state. Fix this for reproducibility of sampling.
    :param is_CI: True for loading a much reduced dataset. Used in CI.
    :return: a tuple of three dataframes: the first holds the training data, the second the \
        validation data and the third the test data. Each has three columns: normalized_images \
        contains the input images rescaled to mean 0.5 and std 0.225, labels contains the label \
        of each image, and images contains the unmodified PIL images.
"""
try:
from datasets import load_dataset, utils
from torchvision import transforms
except ImportError as e:
raise RuntimeError(
"Torchvision and Huggingface datasets are required to load and "
"process the imagenet dataset."
) from e

utils.logging.set_verbosity_error()

preprocess_rgb = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.225, 0.225, 0.225]),
]
)

def _process_dataset(ds):
processed_ds = {"normalized_images": [], "labels": [], "images": []}
for i, item in enumerate(ds):
if item["image"].mode == "RGB":
processed_ds["normalized_images"].append(preprocess_rgb(item["image"]))
processed_ds["images"].append(item["image"])
processed_ds["labels"].append(item["label"])
return pd.DataFrame.from_dict(processed_ds)

if is_CI:
tiny_imagenet = load_dataset("Maysee/tiny-imagenet", split="valid")
if keep_labels is not None:
tiny_imagenet = tiny_imagenet.filter(
lambda item: item["label"] in keep_labels
)
split = tiny_imagenet.shard(2, 0)
tiny_imagenet_test = tiny_imagenet.shard(2, 1)
tiny_imagenet_train = split.shard(5, 0)
tiny_imagenet_val = split.shard(5, 1)
train_ds = _process_dataset(tiny_imagenet_train)
val_ds = _process_dataset(tiny_imagenet_val)
test_ds = _process_dataset(tiny_imagenet_test)
return train_ds, val_ds, test_ds
else:
tiny_imagenet = load_dataset("Maysee/tiny-imagenet", split="train")

if downsample_ds_to_fraction != 1:
        # shard expects an integer number of shards
        tiny_imagenet = tiny_imagenet.shard(int(1 / downsample_ds_to_fraction), 0)
if keep_labels is not None:
tiny_imagenet = tiny_imagenet.filter(lambda item: item["label"] in keep_labels)

split_ds = tiny_imagenet.train_test_split(
train_size=1 - test_size,
seed=random_state,
)
test_ds = _process_dataset(split_ds["test"])

split_ds = split_ds["train"].train_test_split(
train_size=train_size,
seed=random_state,
)
train_ds = _process_dataset(split_ds["train"])
val_ds = _process_dataset(split_ds["test"])
return train_ds, val_ds, test_ds
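
For orientation, a minimal usage sketch of the loader above; the fractions, labels and seed are illustrative values, not taken from the notebook:

# Illustrative call: fractions, labels and seed are made up for this sketch.
train_ds, val_ds, test_ds = load_preprocess_imagenet(
    train_size=0.8,                 # 80% of the non-test data goes to training
    test_size=0.1,                  # 10% of the data is held out for testing
    downsample_ds_to_fraction=0.2,  # keep only 20% of tiny-imagenet to speed things up
    keep_labels=[10, 20],           # restrict to two classes
    random_state=16,
)
# Each returned dataframe has the columns "normalized_images", "labels" and "images".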


def save_model(
model: TorchModel,
train_loss: List[float],
val_loss: List[float],
model_name: str,
):
"""Saves the model weights, with also its training and validation losses.

:param model: trained model
:param train_loss: list of training losses, one per epoch
:param val_loss: list of validation losses, also one per epoch
:param model_name: model name, used for saving the files
"""
torch.save(model.state_dict(), imgnet_model_data_path / f"{model_name}_weights.pth")
with open(
imgnet_model_data_path / f"{model_name}_train_val_loss.pkl", "wb"
) as file:
pkl.dump([train_loss, val_loss], file)


def load_model(model: TorchModel, model_name: str) -> Tuple[List[float], List[float]]:
"""Given the model and the model name, it loads the model weights from the file {model_name}_weights.pth.
Then, it also loads and returns the training and validation losses.

:param model: model
:param model_name: name of the model whose weights have been previously saved
:return: two lists, one with training and one with validation losses.
"""
model.load_state_dict(
torch.load(imgnet_model_data_path / f"{model_name}_weights.pth")
)
with open(
imgnet_model_data_path / f"{model_name}_train_val_loss.pkl", "rb"
) as file:
train_loss, val_loss = pkl.load(file)
return train_loss, val_loss
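
A sketch of the intended round trip with these two helpers; the model name is hypothetical and imgnet_model_data_path (defined above) must exist on disk:

# Hypothetical usage: persist the trained model, then restore it in a later session.
save_model(model, train_loss, val_loss, model_name="tiny_imagenet_cnn")
# ... later, with a freshly constructed model of the same architecture:
train_loss, val_loss = load_model(model, model_name="tiny_imagenet_cnn")
plot_train_val_loss(train_loss, val_loss)  # defined further below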


def plot_sample_images(
dataset: pd.DataFrame,
n_images_per_class: int = 3,
):
"""Given the preprocessed imagenet dataset (or a subset of it), it plots \
a number n_images_per_class of images for each class.

:param dataset: imagenet dataset
:param n_images_per_class: number of images per class to plot
"""
labels = dataset["labels"].unique()
fig, axes = plt.subplots(nrows=n_images_per_class, ncols=len(labels))
fig.suptitle("Examples of training images")
for class_idx, class_label in enumerate(labels):
for img_idx, (_, img_data) in enumerate(
dataset[dataset["labels"] == class_label].iterrows()
):
axes[img_idx, class_idx].imshow(img_data["images"])
axes[img_idx, class_idx].axis("off")
axes[img_idx, class_idx].set_title(f"img label: {class_label}")
if img_idx + 1 >= n_images_per_class:
break
plt.show()


def plot_top_bottom_if_images(
subset_influences: "NDArray[np.float_]",
subset_images: List[JpegImageFile],
num_to_plot: int,
):
"""Given the influence values and the related images, it plots a number 2*num_to_plot of images,
of which those on the right column have the lowest influence, those on the right the highest.

:param subset_influences: an array with influence values
:param subset_images: a list of images
:param num_to_plot: int, number of high and low influence images to plot
"""
top_if_idxs = np.argsort(subset_influences)[-num_to_plot:]
bottom_if_idxs = np.argsort(subset_influences)[:num_to_plot]

fig, axes = plt.subplots(nrows=num_to_plot, ncols=2)
fig.suptitle("Botton (left) and top (right) influences")

for plt_idx, img_idx in enumerate(bottom_if_idxs):
axes[plt_idx, 0].set_title(f"img influence: {subset_influences[img_idx]:0f}")
axes[plt_idx, 0].imshow(subset_images[img_idx])
axes[plt_idx, 0].axis("off")

for plt_idx, img_idx in enumerate(top_if_idxs):
axes[plt_idx, 1].set_title(f"img influence: {subset_influences[img_idx]:0f}")
axes[plt_idx, 1].imshow(subset_images[img_idx])
axes[plt_idx, 1].axis("off")

plt.show()


def plot_train_val_loss(train_loss: List[float], val_loss: List[float]):
"""Plots the train and validation loss

:param train_loss: list of training losses, one per epoch
:param val_loss: list of validation losses, one per epoch
"""
_, ax = plt.subplots()
ax.plot(train_loss, label="Train")
ax.plot(val_loss, label="Val")
ax.set_ylabel("Loss")
ax.set_xlabel("Train epoch")
ax.legend()
plt.show()


def corrupt_imagenet(
dataset: pd.DataFrame,
fraction_to_corrupt: float,
avg_influences: "NDArray[np.float_]",
) -> Tuple[pd.DataFrame, Dict[Any, List[int]]]:
"""Given the preprocessed tiny imagenet dataset (or a subset of it),
it takes a fraction of the images with the highest influence and (randomly)
flips their labels.

:param dataset: preprocessed tiny imagenet dataset
:param fraction_to_corrupt: float, fraction of data to corrupt
:param avg_influences: average influences of each training point on the test set in the \
non-corrupted case.
    :return: first element is the corrupted dataset, second is a dictionary mapping each new \
        (wrong) label to the indices of the images that were corrupted with it.
"""
indices_to_corrupt = []
labels = dataset["labels"].unique()
corrupted_dataset = deepcopy(dataset)
corrupted_indices = {l: [] for l in labels}

avg_influences_series = pd.DataFrame()
avg_influences_series["avg_influences"] = avg_influences
avg_influences_series["labels"] = dataset["labels"]

for label in labels:
class_data = avg_influences_series[avg_influences_series["labels"] == label]
num_corrupt = int(fraction_to_corrupt * len(class_data))
indices_to_corrupt = class_data.nlargest(
num_corrupt, "avg_influences"
).index.tolist()
wrong_labels = [l for l in labels if l != label]
for img_idx in indices_to_corrupt:
sample_label = np.random.choice(wrong_labels)
corrupted_dataset.at[img_idx, "labels"] = sample_label
corrupted_indices[sample_label].append(img_idx)
return corrupted_dataset, corrupted_indices
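
The avg_influences argument is the influence of each training point averaged over the test set. A minimal sketch of how it might be produced and consumed, assuming an influence matrix influences of shape (n_test, n_train) coming from the notebook's influence computation:

# Assumed input: influences, an array of shape (n_test, n_train).
avg_influences = np.mean(influences, axis=0)  # average over test points -> one value per training point
corrupted_ds, corrupted_idx = corrupt_imagenet(
    dataset=train_ds,            # preprocessed dataframe from load_preprocess_imagenet
    fraction_to_corrupt=0.2,     # illustrative fraction
    avg_influences=avg_influences,
)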


def get_mean_corrupted_influences(
corrupted_dataset: pd.DataFrame,
corrupted_indices: Dict[Any, List[int]],
avg_corrupted_influences: "NDArray[np.float_]",
) -> pd.DataFrame:
"""Given a corrupted dataset, it returns a dataframe with average influence for each class
and separately for corrupted (and non) point.

    :param corrupted_dataset: corrupted dataset as returned by corrupt_imagenet
    :param corrupted_indices: dictionary of corrupted indices, as returned by corrupt_imagenet
:param avg_corrupted_influences: average influence of each training point on the test dataset
:return: a dataframe holding the average influence of corrupted and non-corrupted data
"""
labels = corrupted_dataset["labels"].unique()
avg_label_influence = pd.DataFrame(
columns=["label", "avg_non_corrupted_infl", "avg_corrupted_infl", "score_diff"]
)
for idx, label in enumerate(labels):
avg_influences_series = pd.Series(avg_corrupted_influences)
class_influences = avg_influences_series[corrupted_dataset["labels"] == label]
corrupted_infl = class_influences[
class_influences.index.isin(corrupted_indices[label])
]
non_corrupted_infl = class_influences[
~class_influences.index.isin(corrupted_indices[label])
]
avg_non_corrupted = np.mean(non_corrupted_infl)
avg_corrupted = np.mean(corrupted_infl)
avg_label_influence.loc[idx] = [
label,
avg_non_corrupted,
avg_corrupted,
avg_non_corrupted - avg_corrupted,
]
return avg_label_influence
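
An illustrative follow-up, assuming avg_corrupted_influences was obtained by re-running the influence computation on the corrupted data:

# Inspect the per-label averages computed above.
score_df = get_mean_corrupted_influences(
    corrupted_ds, corrupted_idx, avg_corrupted_influences
)
print(score_df.sort_values("score_diff", ascending=False))
# A positive score_diff means corrupted points received, on average, lower influence than clean ones.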


def plot_corrupted_influences_distribution(
corrupted_dataset: pd.DataFrame,
corrupted_indices: Dict[Any, List[int]],
avg_corrupted_influences: "NDArray[np.float_]",
):
"""Given a corrupted dataset, it plots the histogram with the distribution of
influence values. This is done separately for each label: each has a plot where
the distribution of the influence of non-corrupted points is compared to that of corrupted ones

:param corrupted_dataset: corrupted dataset as returned by get_corrupted_imagenet
:param corrupted_indices: list of corrupted indices, as returned by get_corrupted_imagenet
:param avg_corrupted_influences: average influence of each training point on the test dataset
:return: a dataframe holding the average influence of corrupted and non-corrupted data
"""
labels = corrupted_dataset["labels"].unique()
fig, axes = plt.subplots(nrows=1, ncols=2)
fig.suptitle("Distribution of corrupted and clean influences.")
for idx, label in enumerate(labels):
avg_influences_series = pd.Series(avg_corrupted_influences)
class_influences = avg_influences_series[corrupted_dataset["labels"] == label]
corrupted_infl = class_influences[
class_influences.index.isin(corrupted_indices[label])
]
non_corrupted_infl = class_influences[
~class_influences.index.isin(corrupted_indices[label])
]
axes[idx].hist(
non_corrupted_infl, label="non corrupted data", density=True, alpha=0.7
)
axes[idx].hist(
corrupted_infl,
label="corrupted data",
density=True,
alpha=0.7,
color="green",
)
axes[idx].set_xlabel("influence values")
axes[idx].set_ylabel("Distribution")
axes[idx].set_title(f"Influences for {label=}")
axes[idx].legend()
plt.show()