diff --git a/src/pytom_tm/tmjob.py b/src/pytom_tm/tmjob.py index 3200399d..566d3f51 100644 --- a/src/pytom_tm/tmjob.py +++ b/src/pytom_tm/tmjob.py @@ -44,6 +44,10 @@ def load_json_to_tmjob( with open(file_name, "r") as fstream: data = json.load(fstream) + # wrangle dtypes + output_dtype = data.get("output_dtype", "float32") + output_dtype = np.dtype(output_dtype) + job = TMJob( data["job_key"], data["log_level"], @@ -74,7 +78,7 @@ def load_json_to_tmjob( particle_diameter=data.get("particle_diameter", None), random_phase_correction=data.get("random_phase_correction", False), rng_seed=data.get("rng_seed", 321), - output_dtype=data.get("output_dtype", np.float32), + output_dtype=output_dtype, ) # if the file originates from an old version set the phase shift for compatibility if ( @@ -527,6 +531,8 @@ def write_to_json(self, file_name: pathlib.Path) -> None: for key, value in d.items(): if isinstance(value, pathlib.Path): d[key] = str(value) + if isinstance(value, np.dtype): + d[key] = str(value) with open(file_name, "w") as fstream: json.dump(d, fstream, indent=4)