diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 12bd76ba60..06444a8769 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -442,7 +442,8 @@ def astype(self, dtype, device=None, *_args, **_kwargs): _kwargs: additional kwargs (currently unused). Returns: - data array instance + ``MetaTensor`` when a torch dtype is given (metadata is preserved), + or ``np.ndarray`` when a numpy dtype is given. """ if isinstance(dtype, str): mod_str, *dtype = dtype.split(".", 1) @@ -453,7 +454,7 @@ def astype(self, dtype, device=None, *_args, **_kwargs): out_type: type[torch.Tensor] | type[np.ndarray] | None if mod_str == "torch": - out_type = torch.Tensor + out_type = type(self) elif mod_str in ("numpy", "np"): out_type = np.ndarray else: diff --git a/tests/data/meta_tensor/test_meta_tensor.py b/tests/data/meta_tensor/test_meta_tensor.py index c0e53fd24c..a12f519f62 100644 --- a/tests/data/meta_tensor/test_meta_tensor.py +++ b/tests/data/meta_tensor/test_meta_tensor.py @@ -434,8 +434,12 @@ def test_astype(self): for np_types in ("float32", "np.float32", "numpy.float32", np.float32, float, "int", np.uint16): self.assertIsInstance(t.astype(np_types), np.ndarray) for pt_types in ("torch.float", torch.float, "torch.float64"): - self.assertIsInstance(t.astype(pt_types), torch.Tensor) - self.assertIsInstance(t.astype("torch.float", device="cpu"), torch.Tensor) + result = t.astype(pt_types) + self.assertIsInstance(result, MetaTensor) + self.assertEqual(result.meta.get("fname"), "filename") + result = t.astype("torch.float", device="cpu") + self.assertIsInstance(result, MetaTensor) + self.assertEqual(result.meta.get("fname"), "filename") def test_transforms(self): key = "im"