When using a torch model with `TorchModuleWrapper`, mixed precision doesn't work.
I guess that somewhere in `TorchModuleWrapper`'s `call`, the call to the torch model should be wrapped in `with torch.cuda.amp.autocast():`.
Here is a minimal example that doesn't work:
```python
import os

os.environ["KERAS_BACKEND"] = "torch"

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import keras

keras.mixed_precision.set_global_policy("mixed_float16")

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
train_dataloader = DataLoader(training_data, batch_size=64)


class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = NeuralNetwork().to("cuda")
inputs = keras.layers.Input(shape=(1, 28, 28))
outputs = keras.layers.TorchModuleWrapper(model)(inputs)
keras_model = keras.models.Model(inputs, outputs)
keras_model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
keras_model.fit(train_dataloader)
```
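The autocast idea suggested in the issue can be illustrated at the plain PyTorch level, without Keras. The sketch below is a minimal illustration under stated assumptions: `call_with_autocast` is a hypothetical helper (not part of Keras or PyTorch), and it uses CPU with `bfloat16` so it runs without a GPU. On a GPU with float16 the equivalent context would be `torch.autocast(device_type="cuda", dtype=torch.float16)`; recent PyTorch releases deprecate `torch.cuda.amp.autocast()` in favor of that form. It shows that a low-precision input fed to float32 weights fails on its own, but succeeds once the call runs under autocast.

```python
import torch
from torch import nn


def call_with_autocast(module, x, device_type="cpu", dtype=torch.bfloat16):
    """Hypothetical helper: run module(x) inside torch.autocast.

    Under autocast, eligible ops such as linear cast their inputs and
    weights to `dtype`, so a low-precision input no longer clashes with
    float32 parameters.
    """
    with torch.autocast(device_type=device_type, dtype=dtype):
        return module(x)


# Float32 weights, like the layers in the issue's NeuralNetwork.
layer = nn.Linear(28 * 28, 10)
# Low-precision input, similar to what a mixed-precision policy passes in.
x = torch.randn(4, 28 * 28).to(torch.bfloat16)

try:
    layer(x)  # no autocast: input dtype and weight dtype differ
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

out = call_with_autocast(layer, x)
print(mismatch_raised, out.dtype)
```

Wrapping the whole forward call (rather than casting the module) keeps the parameters in float32 and only lowers the precision of eligible ops, which is closer to what mixed precision is meant to do.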
Thanks for reporting this issue.
The error you are getting is `mat1 and mat2 must have the same dtype, but got Half and Float`. It occurs because the input to the torch model is half precision (float16), while the weights and operations inside the model are float32 (`Float`). Since you are using `keras.mixed_precision.set_global_policy("mixed_float16")`, you can explicitly call `model = model.half()`, which converts all of the model's parameters to float16.
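A minimal sketch of what `model.half()` changes, assuming an `nn.Sequential` stand-in with the same layers as the issue's `NeuralNetwork`. It only inspects parameter dtypes and deliberately skips a forward pass, since float16 matrix multiplication on CPU is not supported in every PyTorch version.

```python
import torch
from torch import nn

# Same layers as the NeuralNetwork in the issue (flatten + MLP).
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 10),
)

# Before: all parameters are float32, so a float16 input would mismatch.
dtypes_before = {p.dtype for p in model.parameters()}

# .half() casts floating-point parameters and buffers to float16 in place
# and returns the module itself.
model = model.half()
dtypes_after = {p.dtype for p in model.parameters()}

print(dtypes_before, dtypes_after)
```

Note the trade-off: under the `mixed_float16` policy, native Keras layers keep their variables in float32 and only compute in float16, whereas after `.half()` the torch weights themselves are stored in float16.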
Attaching a gist for your reference.