diff --git a/src/decoding.cc b/src/decoding.cc index 90fa75b14..418389e2c 100644 --- a/src/decoding.cc +++ b/src/decoding.cc @@ -70,9 +70,14 @@ namespace ctranslate2 { StorageView& ids) { if (!decoder.output_layer_is_updated()) return; + ctranslate2::Device device = ids.device(); + if (device != Device::CPU) + ids = ids.to(Device::CPU); auto* ids_data = ids.data(); for (dim_t i = 0; i < ids.size(); ++i) ids_data[i] = decoder.to_original_word_id(ids_data[i]); + if (ids.device() != device) + ids = ids.to(device); } template