[TensorFlow] Adding LeViT #19152

Closed · wants to merge 18 commits

Changes from 1 commit
chore: aligning till attention biases
ariG23498 committed Oct 9, 2022
commit 57f5f74dc3e7036a1137e92ca6afaf4b330a85eb
33 changes: 16 additions & 17 deletions src/transformers/models/levit/modeling_tf_levit.py
@@ -16,12 +16,11 @@

 import itertools
 from dataclasses import dataclass
-from typing import Optional, Tuple, Dict
-from numpy import indices
+from typing import Dict, Optional, Tuple

 import tensorflow as tf
-from tensorflow.keras.losses import MeanSquaredError, BinaryCrossentropy, CategoricalCrossentropy
 from tensorflow.keras import backend as K
+from tensorflow.keras.losses import BinaryCrossentropy, CategoricalCrossentropy, MeanSquaredError

 from ...modeling_outputs import ModelOutput
 from ...modeling_tf_outputs import (
@@ -59,7 +58,7 @@
 @dataclass
 class TFLevitForImageClassificationWithTeacherOutput(ModelOutput):
     """
-    Output type of [`LevitForImageClassificationWithTeacher`].
+    Output type of [`TFLevitForImageClassificationWithTeacher`].

     Args:
         logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
@@ -95,18 +94,20 @@ def __init__(
             filters=out_channels,
             kernel_size=kernel_size,
             strides=stride,
-            padding=(padding, padding),  # TODO @ariG23498: Make sure the padding is a tuple
+            padding="SAME",  # TODO @ariG23498: Make sure the padding is a tuple
             dilation_rate=dilation,
             groups=groups,
             use_bias=False,
-            data_format="channels_first",
+            data_format="channels_last",
             name="convolution",
         )
         # The epsilon and momentum used here are the defaults in torch batch norm layer.
         self.batch_norm = tf.keras.layers.BatchNormalization(epsilon=1e-05, momentum=0.1, name="batch_norm")

     def call(self, embeddings, training=None):
+        embeddings = tf.transpose(embeddings, perm=(0, 2, 3, 1))
         embeddings = self.convolution(embeddings, training=training)
+        embeddings = tf.transpose(embeddings, perm=(0, 3, 1, 2))
         embeddings = self.batch_norm(embeddings, training=training)
         return embeddings

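The transpose pair added in `call` converts NCHW activations to NHWC for the Keras convolution and back again, since `Conv2D` with `data_format="channels_first"` is not supported on CPU. A minimal standalone sketch of the shape round trip, with illustrative sizes not taken from the PR:

```python
import tensorflow as tf

pixel_values = tf.random.normal((2, 3, 224, 224))  # NCHW, matching the PyTorch layout

# NCHW -> NHWC so the channels_last convolution can run on CPU.
nhwc = tf.transpose(pixel_values, perm=(0, 2, 3, 1))  # (2, 224, 224, 3)
conv = tf.keras.layers.Conv2D(filters=16, kernel_size=3, strides=2, padding="same", use_bias=False)
nhwc = conv(nhwc)  # (2, 112, 112, 16)

# NHWC -> NCHW so the rest of the model keeps the PyTorch layout.
nchw = tf.transpose(nhwc, perm=(0, 3, 1, 2))  # (2, 16, 112, 112)

# Caveat on the batch-norm arguments: PyTorch updates running statistics as
# (1 - momentum) * running + momentum * batch with default momentum=0.1, while
# Keras computes momentum * moving + (1 - momentum) * batch, so torch's 0.1
# corresponds to Keras momentum=0.9; epsilon=1e-05 carries over unchanged. For
# NCHW inputs the channel axis must also be given explicitly (axis=1).
batch_norm = tf.keras.layers.BatchNormalization(axis=1, epsilon=1e-05, momentum=0.9)
out = batch_norm(nchw, training=False)
print(out.shape)  # (2, 16, 112, 112)
```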
@@ -181,6 +182,7 @@ def call(self, pixel_values, training=None):
         embeddings = self.activation_layer_3(embeddings)
         embeddings = self.embedding_layer_4(embeddings, training=training)
         # Flatten the embeddings
+        num_channels = tf.shape(embeddings)[1]
         flattended_embeddings = tf.reshape(embeddings, shape=(batch_size, num_channels, -1))
         # Transpose the channel and spatial axis of the flattened embeddings
         transpose_embeddings = tf.transpose(flattended_embeddings, perm=(0, 2, 1))
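The added `num_channels` line recovers the channel axis dynamically before the flatten. As a shape sketch (sizes illustrative), the two steps map `(batch, channels, height, width)` to `(batch, height * width, channels)`, i.e. one token per spatial position:

```python
import tensorflow as tf

embeddings = tf.random.normal((2, 256, 14, 14))  # (batch, C, H, W), illustrative sizes
batch_size = tf.shape(embeddings)[0]
num_channels = tf.shape(embeddings)[1]

# Collapse the two spatial axes into one: (2, 256, 196).
flattened = tf.reshape(embeddings, shape=(batch_size, num_channels, -1))

# Swap channels and positions to get a token sequence: (2, 196, 256).
tokens = tf.transpose(flattened, perm=(0, 2, 1))
print(tokens.shape)
```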
@@ -275,7 +277,7 @@ def build(self, input_shape):
         self.attention_bias_idxs = tf.Variable(
             initial_value=tf.reshape(self.indices, (self.len_points, self.len_points)),
             trainable=False,  # this is a registered buffer and not a parameter
-            dtype=tf.float32,
+            dtype=tf.int32,
             name="attention_bias_idxs",
         )
         super().build(input_shape)
@@ -293,6 +295,8 @@ def get_attention_biases(self, device, training=None):
         else:
             device_key = str(device)
         if device_key not in self.attention_bias_cache:
+            print("INFO biases cache", self.attention_biases.shape)
+            print("INFO biases index", self.attention_bias_idxs.shape)
             self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
         return self.attention_bias_cache[device_key]

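The cached lookup `self.attention_biases[:, self.attention_bias_idxs]` relies on NumPy-style advanced indexing, which PyTorch supports but TensorFlow's tensor `__getitem__` does not; `tf.gather` along axis 1 expresses the same operation. A sketch under assumed shapes, where `attention_biases` is `(num_heads, num_offsets)` and `attention_bias_idxs` is `(seq_len, seq_len)`:

```python
import tensorflow as tf

num_heads, num_offsets, seq_len = 4, 49, 196  # illustrative sizes, not from the PR
attention_biases = tf.random.normal((num_heads, num_offsets))  # learned per-head biases
attention_bias_idxs = tf.random.uniform(
    (seq_len, seq_len), maxval=num_offsets, dtype=tf.int32  # one offset id per (query, key) pair
)

# Equivalent of torch's attention_biases[:, attention_bias_idxs]:
# gather along the offset axis, giving shape (num_heads, seq_len, seq_len).
biases = tf.gather(attention_biases, attention_bias_idxs, axis=1)
print(biases.shape)  # (4, 196, 196)
```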
@@ -381,7 +385,7 @@ def build(self, input_shape):
         self.attention_bias_idxs = tf.Variable(
             initial_value=tf.reshape(self.indices, (self.len_points_, self.len_points)),
             trainable=False,
-            dtype=tf.float32,
+            dtype=tf.int32,
             name="attention_bias_idxs",
         )
         super().build(input_shape)
@@ -498,13 +502,12 @@ def __init__(
         self.config = config
         self.resolution_in = resolution_in
         # resolution_in is the intial resolution, resolution_out is final resolution after downsampling
-
-        for idx in range(depths):
+        for index in range(depths):
             self.layers.append(
                 TFLevitResidualLayer(
                     TFLevitAttention(hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution_in),
                     self.config.drop_path_rate,
-                    name=f"layers.{idx}",
+                    name=f"layers.{index}",
                 )
             )
             if mlp_ratio > 0:
@@ -513,19 +516,15 @@ def __init__(
                 TFLevitResidualLayer(
                     TFLevitMLPLayer(hidden_sizes, hidden_dim),
                     self.config.drop_path_rate,
-                    name=f"layers.{idx}",
+                    name=f"layers.{index}",
                 )
             )

         if down_ops[0] == "Subsample":
-
-            print("info", self.config.hidden_sizes)
-            print("info", idx)
             self.resolution_out = (self.resolution_in - 1) // down_ops[5] + 1
             self.layers.append(
                 TFLevitAttentionSubsample(
-                    input_dim=self.config.hidden_sizes[idx],
-                    output_dim=self.config.hidden_sizes[idx + 1],
+                    *self.config.hidden_sizes[idx : idx + 2],
                     key_dim=down_ops[1],
                     num_attention_heads=down_ops[2],
                     attention_ratio=down_ops[3],