diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 823f790f41030e..8a461760ef0c9b 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -7623,8 +7623,15 @@ def _get_paddle_place(place): device_id = int(device_id) return core.IPUPlace(device_id) + place_info_list = place.split(':', 1) + device_type = place_info_list[0] + if device_type in core.get_all_custom_device_type(): + device_id = place_info_list[1] + device_id = int(device_id) + return core.CustomPlace(device_type, device_id) + raise ValueError( - f"Paddle supports CPUPlace, CUDAPlace, CUDAPinnedPlace, XPUPlace and IPUPlace, but received {place}." + f"Paddle supports CPUPlace, CUDAPlace, CUDAPinnedPlace, XPUPlace, IPUPlace and CustomPlace, but received {place}." )