
CNN influence example #195

Merged
merged 31 commits into from Dec 29, 2022
Changes from 10 commits (31 commits total)
93ec7f2
WIP influence example using imagenet and resnet
Xuzzo Nov 8, 2022
7b73a98
WIP extending example of cnn with ifs
Xuzzo Nov 9, 2022
17a45e9
WIP notebook
Xuzzo Nov 11, 2022
7990df1
WIP still working on damn example
Xuzzo Nov 11, 2022
b8df585
WIP writing docs for notebook
Xuzzo Nov 13, 2022
4800233
WIP writing theory in example
Xuzzo Nov 23, 2022
9c62032
Merge branch 'develop' into cnn_influence_example
Xuzzo Dec 6, 2022
7ada9a6
add theory appendix to imagenet notebook
Xuzzo Dec 6, 2022
1f36c05
moved methods to notebook support
Xuzzo Dec 6, 2022
c6c6da8
update to imagenet notebook - docs to notebook support methods
Xuzzo Dec 7, 2022
d203ad4
cosmetic changes to notebook
Xuzzo Dec 9, 2022
7d38b89
trying to solve tox issue in ci
Xuzzo Dec 9, 2022
b70432d
fix ci data loading in notebook
Xuzzo Dec 9, 2022
b16287e
minor changes to docs
Xuzzo Dec 9, 2022
bcc745a
update changelog and fix typing in notebook_support
Xuzzo Dec 12, 2022
cb18bdf
including progress bar
Xuzzo Dec 19, 2022
f371af8
addressing MR comments
Xuzzo Dec 20, 2022
0c7f466
dummy commit to re-trigger pipelines
Xuzzo Dec 21, 2022
d156f6e
minor changes to notebooks
Xuzzo Dec 21, 2022
eb63458
add sphinx hidden to cells
Xuzzo Dec 22, 2022
461eb27
minor changes to notebook
Xuzzo Dec 22, 2022
b4793f3
Forward calls to wrapped torch model
mdbenito Dec 27, 2022
f059b6b
Return ndarrays in TorchModel.fit
mdbenito Dec 27, 2022
d193eec
Types, strings, etc.
mdbenito Dec 27, 2022
826ae08
Remove clutter from notebook, some refactoring, rephrasing and tweaking.
mdbenito Dec 27, 2022
1e9e0a3
git ignore saved models
mdbenito Dec 27, 2022
c7256fc
Merge pull request #235 from appliedAI-Initiative/fix/cnn-influence
Xuzzo Dec 28, 2022
170970d
dummy commit to trigger pipeline
Xuzzo Dec 28, 2022
a31b263
fix typing
Xuzzo Dec 28, 2022
2d3a719
minor changes to notebooks
Xuzzo Dec 28, 2022
9f704d0
add req-notebooks and remove InternalDataset
Xuzzo Dec 28, 2022
1,258 changes: 1,258 additions & 0 deletions notebooks/influence_imagenet.ipynb

Large diffs are not rendered by default.

23 changes: 10 additions & 13 deletions notebooks/influence_wine.ipynb

Large diffs are not rendered by default.

316 changes: 316 additions & 0 deletions notebooks/notebook_support.py
@@ -1,14 +1,27 @@
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from cloudpickle import pickle as pkl

from pydvl.utils import Dataset

try:
import torch

_TORCH_INSTALLED = True
except ImportError:
_TORCH_INSTALLED = False

if TYPE_CHECKING:
from numpy.typing import NDArray

imgnet_model_data_path = Path().resolve().parent / "data/imgnet_model"


def plot_dataset(
train_ds: Tuple["NDArray", "NDArray"],
@@ -258,3 +271,306 @@ def _handle_legend(scatter):
edgecolors="r",
s=80,
)


def load_preprocess_imagenet(
train_size: float,
test_size: float,
downsample_ds_to_fraction: float = 1,
keep_labels: Optional[List] = None,
random_state: Optional[int] = None,
is_CI: bool = False,
):
"""Loads the tiny imagened dataset from huggingface and preprocesses it
for model input.

:param train_size: fraction of indices to use for training
:param test_size: fraction of data to use for testing
:param downsample_ds_to_fraction: which fraction of the full dataset to keep. \
E.g. downsample_ds_to_fraction=0.2 only 20% of the dataset is kept
:param keep_labels: which of the original labels to keep. \
E.g. keep_labels=[10,20] only returns the images with labels 10 and 20.
:param random_state: Random state. Fix this for reproducibility of sampling.
:param is_CI: True for loading a much reduced dataset. Used in CI.
:return: a tuple of three dataframes, first holding the training data, second validation, third test. \
Each has 3 keys: normalized_images has all the input images, rescaled to mean 0.5 and std 0.225, \
labels has the labels of each image, while images has the unmodified images.
"""
try:
from datasets import load_dataset
from torchvision import transforms
except ImportError as e:
raise RuntimeError(
"Torchvision and Huggingface datasets are required to load and "
"process the imagenet dataset."
) from e

preprocess_rgb = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.225, 0.225, 0.225]),
]
)

def _process_dataset(ds):
processed_ds = {"normalized_images": [], "labels": [], "images": []}
for i, item in enumerate(ds):
if item["image"].mode == "RGB":
processed_ds["normalized_images"].append(preprocess_rgb(item["image"]))
processed_ds["images"].append(item["image"])
processed_ds["labels"].append(item["label"])
return pd.DataFrame.from_dict(processed_ds)

if is_CI:
tiny_imagenet = load_dataset("Maysee/tiny-imagenet", split="val")
tiny_imagenet = tiny_imagenet.shard(num_shards=10, index=0)
else:
tiny_imagenet = load_dataset("Maysee/tiny-imagenet", split="train")

if downsample_ds_to_fraction != 1:
tiny_imagenet = tiny_imagenet.shard(num_shards=int(1 / downsample_ds_to_fraction), index=0)
if keep_labels is not None:
tiny_imagenet = tiny_imagenet.filter(lambda item: item["label"] in keep_labels)

split_ds = tiny_imagenet.train_test_split(
train_size=1 - test_size,
seed=random_state,
)
test_ds = _process_dataset(split_ds["test"])

split_ds = split_ds["train"].train_test_split(
train_size=train_size,
seed=random_state,
)
train_ds = _process_dataset(split_ds["train"])
val_ds = _process_dataset(split_ds["test"])
return train_ds, val_ds, test_ds
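
A minimal usage sketch for this loader (not part of the diff; the label IDs and split fractions below are arbitrary examples):

# Illustrative usage; labels 10 and 20 are arbitrary example classes.
train_ds, val_ds, test_ds = load_preprocess_imagenet(
    train_size=0.8,
    test_size=0.1,
    keep_labels=[10, 20],
    random_state=16,
)
# Each returned DataFrame has columns: normalized_images, labels, images.
print(len(train_ds), len(val_ds), len(test_ds))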


def save_model(model, train_loss, val_loss, model_name):
"""Saves the model weights, with also its training and validation losses.

:param model: trained model
:param train_loss: list of training losses, one per epoch
:param val_loss: list of validation losses, also one per epoch
:param model_name: model name, used for saving the files
"""
torch.save(model.state_dict(), imgnet_model_data_path / f"{model_name}_weights.pth")
with open(
imgnet_model_data_path / f"{model_name}_train_val_loss.pkl", "wb"
) as file:
pkl.dump([train_loss, val_loss], file)


def load_model(model, model_name):
"""Given the model and the model name, it loads the model weights from the file {model_name}_weights.pth.
Then , it also loads and returns the training and validation losses.

:param model: model
:param model_name: name of the model whose weights have been previously saved
:return: two lists, one with training and one with validation losses.
"""
model.load_state_dict(
torch.load(imgnet_model_data_path / f"{model_name}_weights.pth")
)
with open(
imgnet_model_data_path / f"{model_name}_train_val_loss.pkl", "rb"
) as file:
train_loss, val_loss = pkl.load(file)
return train_loss, val_loss
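
A hypothetical save/load round trip for these two helpers (the model name is made up for illustration, and imgnet_model_data_path must already exist):

# Illustrative only: `model`, `train_loss` and `val_loss` would come from
# a training loop like the notebook's; "resnet18_tiny" is a made-up name.
save_model(model, train_loss, val_loss, model_name="resnet18_tiny")
# Later, restore the weights into a freshly constructed model of the same class:
train_loss, val_loss = load_model(model, model_name="resnet18_tiny")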


def save_results(results, file_name):
"""Saves (pickles) any file to {file_name}.pkl

:param results: any serializable object
:param file_name: string, file name where to save the object
"""
with open(imgnet_model_data_path / f"{file_name}", "wb") as file:
pkl.dump(results, file)


def load_results(file_name):
"""Loads the pickle file {file_name}.pkl

:param file_name: string, file name where the object is saved
:return: saved object
"""
with open(imgnet_model_data_path / f"{file_name}", "rb") as file:
results = pkl.load(file)
return results


def plot_sample_images(
dataset,
n_images_per_class=3,
):
"""Given the preprocessed imagenet dataset (or a subset of it), it plots \
a number n_images_per_class of images for each class.

:param dataset: imagenet dataset
:param n_images_per_class: int, number of images per class to plot
"""
labels = dataset["labels"].unique()
fig, axes = plt.subplots(nrows=n_images_per_class, ncols=len(labels))
fig.suptitle("Examples of training images")
for class_idx, class_label in enumerate(labels):
for img_idx, (_, img_data) in enumerate(
dataset[dataset["labels"] == class_label].iterrows()
):
axes[img_idx, class_idx].imshow(img_data["images"])
axes[img_idx, class_idx].axis("off")
axes[img_idx, class_idx].set_title(f"img label: {class_label}")
if img_idx + 1 >= n_images_per_class:
break
plt.show()


def plot_top_bottom_if_images(
subset_influences,
subset_images,
num_to_plot,
):
"""Given the influence values and the related images, it plots a number 2* num_to_plot of images,
of which those on the right column have the lowest influence, those on the right the highest.

:param subset_influences: an array with influence values
:param subset_images: an array of images
:param num_to_plot: int, number of high and low influence images to plot
"""
top_if_idxs = np.argsort(subset_influences)[-num_to_plot:]
bottom_if_idxs = np.argsort(subset_influences)[:num_to_plot]

fig, axes = plt.subplots(nrows=num_to_plot, ncols=2)
fig.suptitle("Botton (left) and top (right) influences")

for plt_idx, img_idx in enumerate(bottom_if_idxs):
axes[plt_idx, 0].set_title(f"img influence: {subset_influences[img_idx]:0f}")
axes[plt_idx, 0].imshow(subset_images[img_idx])
axes[plt_idx, 0].axis("off")

for plt_idx, img_idx in enumerate(top_if_idxs):
axes[plt_idx, 1].set_title(f"img influence: {subset_influences[img_idx]:0f}")
axes[plt_idx, 1].imshow(subset_images[img_idx])
axes[plt_idx, 1].axis("off")

plt.show()


def plot_train_val_loss(train_loss, val_loss):
"""Plots the train and validation loss

:param train_loss: list of training losses, one per epoch
:param val_loss: list of validation losses, one per epoch
"""
_, ax = plt.subplots()
ax.plot(train_loss, label="Train")
ax.plot(val_loss, label="Val")
ax.set_ylabel("Loss")
ax.set_xlabel("Train epoch")
ax.legend()
plt.show()


def get_corrupted_imagenet(dataset, fraction_to_corrupt, avg_influences):
"""Given the preprocessed tiny imagenet dataset (or a subset of it),
it takes a fraction of the images with the highest influence and (randomly)
flips their labels.

:param dataset: preprocessed tiny imagenet dataset
:param fraction_to_corrupt: float, fraction of data to corrupt
:param avg_influences: average influences of each training point on the test set in the \
non-corrupted case.
:return: first element is the corrupted dataset, second is a dictionary mapping \
each label to the indices of the images whose label was flipped to it.
"""
labels = dataset["labels"].unique()
corrupted_dataset = deepcopy(dataset)
corrupted_indices = {l: [] for l in labels}

avg_influences_series = pd.DataFrame()
avg_influences_series["avg_influences"] = avg_influences
avg_influences_series["labels"] = dataset["labels"]

for label in labels:
class_data = avg_influences_series[avg_influences_series["labels"] == label]
num_corrupt = int(fraction_to_corrupt * len(class_data))
indices_to_corrupt = class_data.nlargest(
num_corrupt, "avg_influences"
).index.tolist()
wrong_labels = [l for l in labels if l != label]
for img_idx in indices_to_corrupt:
sample_label = np.random.choice(wrong_labels)
corrupted_dataset.at[img_idx, "labels"] = sample_label
corrupted_indices[sample_label].append(img_idx)
return corrupted_dataset, corrupted_indices
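
A usage sketch, assuming `train_ds` and `avg_influences` come from the earlier loading and influence-computation steps (names illustrative):

# Illustrative: corrupt the 10% most influential images of each class.
corrupted_ds, corrupted_idx = get_corrupted_imagenet(
    train_ds, fraction_to_corrupt=0.1, avg_influences=avg_influences
)
# corrupted_idx maps each label to the indices flipped *to* that label.
n_flipped = sum(len(v) for v in corrupted_idx.values())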


def plot_influence_distribution_by_label(influences, dataset):
"""For each label in dataset it plots the histogram of the distribution of
influence values.

:param influences: array of influences
:param dataset: (preprocessed) tiny-imagenet dataset
"""
_, ax = plt.subplots()
labels = dataset["labels"].unique()
for label in labels:
ax.hist(influences[dataset["labels"] == label], label=label, alpha=0.7)
ax.set_xlabel("influence values")
ax.set_ylabel("number of points")
ax.set_title("Influence distribution")
ax.legend()
plt.show()


def plot_corrupted_influences_distribution(
corrupted_dataset,
corrupted_indices,
avg_corrupted_influences,
):
"""Given a corrupted dataset, it plots the histogram with the distribution of
influence values. This is done separately for each label: each has a plot where
the distribution of the influence of non-corrupted points is compared to that of corrupted ones

:param corrupted_dataset: corrupted dataset as returned by get_corrupted_imagenet
:param corrupted_indices: dictionary of corrupted indices, as returned by get_corrupted_imagenet
:param avg_corrupted_influences: average influence of each training point on the test dataset
"""
labels = corrupted_dataset["labels"].unique()
fig, axes = plt.subplots(nrows=1, ncols=2)
fig.suptitle("Distribution of corrupted and clean influences.")
avg_label_influence = pd.DataFrame(
columns=["label", "avg_non_corrupted_infl", "avg_corrupted_infl"]
)
for idx, label in enumerate(labels):
avg_influences_series = pd.Series(avg_corrupted_influences)
class_influences = avg_influences_series[corrupted_dataset["labels"] == label]
corrupted_infl = class_influences[
class_influences.index.isin(corrupted_indices[label])
]
non_corrupted_infl = class_influences[
~class_influences.index.isin(corrupted_indices[label])
]
avg_label_influence.loc[idx] = [
label,
np.mean(non_corrupted_infl),
np.mean(corrupted_infl),
]
axes[idx].hist(
non_corrupted_infl, label="non corrupted data", density=True, alpha=0.7
)
axes[idx].hist(
corrupted_infl,
label="corrupted data",
density=True,
alpha=0.7,
color="green",
)
axes[idx].set_xlabel("influence values")
axes[idx].set_ylabel("Distribution")
axes[idx].set_title(f"Influences for {label=}")
axes[idx].legend()
plt.show()
return avg_label_influence.astype({"label": "int32"})
1 change: 1 addition & 0 deletions requirements.txt
@@ -7,3 +7,4 @@ pymemcache
cloudpickle
tqdm
matplotlib
datasets
2 changes: 1 addition & 1 deletion src/pydvl/influence/conjugate_gradient.py
@@ -33,7 +33,7 @@ def conjugate_gradient(A: "NDArray", batch_y: "NDArray") -> "NDArray":
"""
batch_cg = []
for y in batch_y:
- y_cg, _ = cg(A.T, y)
+ y_cg, _ = cg(A, y)
batch_cg.append(y_cg)
return np.asarray(batch_cg)
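
The one-line fix above drops the transpose: scipy's cg solves A x = y for the matrix passed in, and in the influence-function setting A is a symmetric Hessian, so cg(A.T, y) and cg(A, y) agree mathematically; passing A directly is simply correct in general. A self-contained sketch (illustrative, not from the PR) of what each loop iteration computes:

import numpy as np
from scipy.sparse.linalg import cg

rng = np.random.default_rng(0)
M = rng.normal(size=(5, 5))
A = M @ M.T + 5 * np.eye(5)  # symmetric positive definite, as cg requires
y = rng.normal(size=5)
x, info = cg(A, y)  # info == 0 signals convergence
assert info == 0
assert np.allclose(A @ x, y, atol=1e-4)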
