Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Image GPT #108

Merged
merged 2 commits into from
Jul 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions pl_bolts/models/vision/image_gpt/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ def forward(self, x):

class GPT2(pl.LightningModule):
def __init__(
self,
embed_dim: int,
heads: int,
layers: int,
num_positions: int,
vocab_size: int,
num_classes: int
self,
embed_dim: int,
heads: int,
layers: int,
num_positions: int,
vocab_size: int,
num_classes: int,
):
"""
GPT-2 from `language Models are Unsupervised Multitask Learners <https://d4mucfpksywv.cloudfront.net/
Expand All @@ -55,7 +55,7 @@ def __init__(
batch_size = 32
vocab_size = 16
x = torch.randint(0, vocab_size, (seq_len, batch_size))
model = GPT2(embed_dim=32, heads=2, layers=2, num_positions=2, vocab_size=vocab_size, num_classes=4)
model = GPT2(embed_dim=32, heads=2, layers=2, num_positions=seq_len, vocab_size=vocab_size, num_classes=4)
results = model(x)
"""
super(GPT2, self).__init__()
Expand All @@ -70,16 +70,22 @@ def _init_sos_token(self):
nn.init.normal_(self.sos)

def _init_embeddings(self):
self.token_embeddings = nn.Embedding(self.hparams.vocab_size, self.hparams.embed_dim)
self.position_embeddings = nn.Embedding(self.hparams.num_positions, self.hparams.embed_dim)
self.token_embeddings = nn.Embedding(
self.hparams.vocab_size, self.hparams.embed_dim
)
self.position_embeddings = nn.Embedding(
self.hparams.num_positions, self.hparams.embed_dim
)

def _init_layers(self):
self.layers = nn.ModuleList()
for _ in range(self.hparams.layers):
self.layers.append(Block(self.hparams.embed_dim, self.hparams.heads))

self.ln_f = nn.LayerNorm(self.hparams.embed_dim)
self.head = nn.Linear(self.hparams.embed_dim, self.hparams.vocab_size, bias=False)
self.head = nn.Linear(
self.hparams.embed_dim, self.hparams.vocab_size, bias=False
)
self.clf_head = nn.Linear(self.hparams.embed_dim, self.hparams.num_classes)

def forward(self, x, classify=False):
Expand Down
55 changes: 31 additions & 24 deletions pl_bolts/models/vision/image_gpt/igpt_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,21 @@ def _shape_input(x):

class ImageGPT(pl.LightningModule):
def __init__(
self,
datamodule: LightningDataModule = None,
embed_dim: int = 16,
heads: int = 2,
layers: int = 2,
pixels: int = 28,
vocab_size: int = 16,
num_classes: int = 10,
classify: bool = False,
batch_size: int = 64,
learning_rate: float = 1e-2,
steps: int = 25_000,
data_dir: str = '.',
num_workers: int = 8,
**kwargs
self,
datamodule: LightningDataModule = None,
embed_dim: int = 16,
heads: int = 2,
layers: int = 2,
pixels: int = 28,
vocab_size: int = 16,
num_classes: int = 10,
classify: bool = False,
batch_size: int = 64,
learning_rate: float = 1e-2,
steps: int = 25_000,
data_dir: str = ".",
num_workers: int = 8,
**kwargs,
):
"""
**Paper**: `Generative Pretraining from Pixels
Expand Down Expand Up @@ -135,7 +135,9 @@ def __init__(

# default to MNIST if no datamodule given
if datamodule is None:
datamodule = FashionMNISTDataModule(self.hparams.data_dir, num_workers=self.hparams.num_workers)
datamodule = FashionMNISTDataModule(
self.hparams.data_dir, num_workers=self.hparams.num_workers
)
self.hparams.pixels = datamodule.size(1)
self.hparams.num_classes = datamodule.num_classes

Expand Down Expand Up @@ -163,11 +165,17 @@ def configure_optimizers(self):
return [optim], [sched]

def forward(self, x, classify=False):
x = _shape_input(x)

# TODO(teddykoker): this is a hack to quantize images into `vocab_size` bins.
# This only works with 1 channel images; something like KNN needs to be used
# for RGB. Assumes data is in [0.0, 1.0].
x = torch.round(x * (self.hparams.vocab_size - 1)).long()

return self.gpt(x, classify)

def training_step(self, batch, batch_idx):
x, y = batch
x = _shape_input(x)

if self.hparams.classify:
clf_logits = self(x, classify=True)
Expand All @@ -181,17 +189,16 @@ def training_step(self, batch, batch_idx):

def validation_step(self, batch, batch_idx):
x, y = batch
x = _shape_input(x)

result = {}
if self.hparams.classify:
clf_logits = self.gpt(x, classify=True)
clf_logits = self(x, classify=True)
loss = self.criterion(clf_logits, y)
_, preds = torch.max(clf_logits, 1)
correct = preds == y
result.update({"val_loss": loss, "correct": correct})
else:
logits = self.gpt(x)
logits = self(x)
logits = logits.view(-1, logits.size(-1))
loss = self.criterion(logits, x.view(-1).long())
result.update({"val_loss": loss})
Expand Down Expand Up @@ -235,7 +242,7 @@ def test_dataloader(self):
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--embed_dim", type=int, default=16)
parser.add_argument("--dataset", type=str, default='fashion_mnist')
parser.add_argument("--dataset", type=str, default="fashion_mnist")
parser.add_argument("--data_dir", type=str, default=os.getcwd())
parser.add_argument("--heads", type=int, default=2)
parser.add_argument("--layers", type=int, default=8)
Expand All @@ -247,7 +254,7 @@ def add_model_specific_args(parent_parser):
return parser


if __name__ == '__main__':
if __name__ == "__main__":
from argparse import ArgumentParser

parser = ArgumentParser()
Expand All @@ -259,10 +266,10 @@ def add_model_specific_args(parent_parser):
parser = ImageGPT.add_model_specific_args(parser)
args = parser.parse_args()

if args.dataset == 'fashion_mnist':
if args.dataset == "fashion_mnist":
datamodule = FashionMNISTDataModule.from_argparse_args(args)

elif args.dataset == 'imagenet128':
elif args.dataset == "imagenet128":
datamodule = ImagenetDataModule.from_argparse_args(args)

model = ImageGPT(**args.__dict__, datamodule=datamodule)
Expand Down
25 changes: 18 additions & 7 deletions tests/models/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,33 @@ def test_igpt(tmpdir):
dm = MNISTDataModule(tmpdir, normalize=False)
model = ImageGPT(datamodule=dm)

trainer = pl.Trainer(limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, max_epochs=1)
trainer = pl.Trainer(
limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, max_epochs=1
)
trainer.fit(model)
trainer.test()
assert trainer.callback_metrics['test_loss'] < 1.7
assert trainer.callback_metrics["test_loss"] < 1.7

model = ImageGPT(classify=True)
trainer = pl.Trainer(limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, max_epochs=1)
trainer = pl.Trainer(
limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, max_epochs=1
)
trainer.fit(model)


def test_gpt2(tmpdir):

seq_len = 17
batch_size = 32
classes = 10
x = torch.randint(0, 10, (seq_len, batch_size))

model = GPT2(embed_dim=16, heads=2, layers=2, num_positions=28 * 28, vocab_size=16, num_classes=classes)
vocab_size = 16
x = torch.randint(0, vocab_size, (seq_len, batch_size))

model = GPT2(
embed_dim=16,
heads=2,
layers=2,
num_positions=seq_len,
vocab_size=vocab_size,
num_classes=10,
)
model(x)