diff --git a/NN_server/Models/ensemble_1200_v44.py b/NN_server/Models/ensemble_1200_v44.py index aa1b561..db22dfa 100644 --- a/NN_server/Models/ensemble_1200_v44.py +++ b/NN_server/Models/ensemble_1200_v44.py @@ -10,6 +10,20 @@ import os import re +def _as_display_image(image): + arr = np.asarray(image) + if arr.ndim == 3 and arr.shape[0] in {1, 3, 4}: + arr = np.moveaxis(arr, 0, -1) + + if np.issubdtype(arr.dtype, np.floating): + arr = np.nan_to_num(arr, nan=0.0, posinf=255.0, neginf=0.0) + if arr.size and (arr.max() > 1.0 or arr.min() < 0.0): + return np.clip(arr, 0, 255).astype(np.uint8) + return np.clip(arr, 0.0, 1.0) + + return arr + + def _render_signal_channel(values, figsize=(16, 16), dpi=16, resize=(256, 256)): import matplotlib.pyplot as plt @@ -232,7 +246,7 @@ def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inf if isinstance(data, (list, tuple)) and len(data) >= 2: fig, ax = plt.subplots() - ax.imshow(np.moveaxis(data[0], 0, -1)) + ax.imshow(_as_display_image(data[0])) plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_real_" + str(model_id) + "_" + model_type + ".png") plt.clf() plt.cla() @@ -241,7 +255,7 @@ def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inf gc.collect() fig, ax = plt.subplots() - ax.imshow(np.moveaxis(data[1], 0, -1)) + ax.imshow(_as_display_image(data[1])) plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_mod_" + str(model_id) + "_" + model_type + ".png") plt.clf() plt.cla() diff --git a/NN_server/Models/ensemble_2400_v44.py b/NN_server/Models/ensemble_2400_v44.py index aa1b561..db22dfa 100644 --- a/NN_server/Models/ensemble_2400_v44.py +++ b/NN_server/Models/ensemble_2400_v44.py @@ -10,6 +10,20 @@ import os import re +def _as_display_image(image): + arr = np.asarray(image) + if arr.ndim == 3 and arr.shape[0] in {1, 3, 4}: + arr = np.moveaxis(arr, 0, -1) + + if np.issubdtype(arr.dtype, np.floating): + arr = np.nan_to_num(arr, nan=0.0, posinf=255.0, neginf=0.0) + if arr.size and (arr.max() > 1.0 or arr.min() < 0.0): + return np.clip(arr, 0, 255).astype(np.uint8) + return np.clip(arr, 0.0, 1.0) + + return arr + + def _render_signal_channel(values, figsize=(16, 16), dpi=16, resize=(256, 256)): import matplotlib.pyplot as plt @@ -232,7 +246,7 @@ def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inf if isinstance(data, (list, tuple)) and len(data) >= 2: fig, ax = plt.subplots() - ax.imshow(np.moveaxis(data[0], 0, -1)) + ax.imshow(_as_display_image(data[0])) plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_real_" + str(model_id) + "_" + model_type + ".png") plt.clf() plt.cla() @@ -241,7 +255,7 @@ def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inf gc.collect() fig, ax = plt.subplots() - ax.imshow(np.moveaxis(data[1], 0, -1)) + ax.imshow(_as_display_image(data[1])) plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_mod_" + str(model_id) + "_" + model_type + ".png") plt.clf() plt.cla() diff --git a/NN_server/Models/ensemble_915_v44.py b/NN_server/Models/ensemble_915_v44.py index d805a4f..a4067de 100644 --- a/NN_server/Models/ensemble_915_v44.py +++ b/NN_server/Models/ensemble_915_v44.py @@ -8,6 +8,20 @@ import gc import io +def _as_display_image(image): + arr = np.asarray(image) + if arr.ndim == 3 and arr.shape[0] in {1, 3, 4}: + arr = np.moveaxis(arr, 0, -1) + + if np.issubdtype(arr.dtype, np.floating): + arr = np.nan_to_num(arr, nan=0.0, posinf=255.0, neginf=0.0) + if arr.size and (arr.max() > 1.0 or arr.min() < 0.0): + return np.clip(arr, 0, 255).astype(np.uint8) + return np.clip(arr, 0.0, 1.0) + + return arr + + def _render_plot(values, figsize=(16, 16), dpi=16): import matplotlib.pyplot as plt @@ -166,7 +180,7 @@ def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inf if int(ind_inference) <= 100 and isinstance(data, (list, tuple)) and len(data) >= 2: fig, ax = plt.subplots() - ax.imshow(np.moveaxis(data[0], 0, -1)) + ax.imshow(_as_display_image(data[0])) plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_real_" + str(model_id) + "_" + model_type + ".png") plt.clf() plt.cla() @@ -175,7 +189,7 @@ def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inf gc.collect() fig, ax = plt.subplots() - ax.imshow(np.moveaxis(data[1], 0, -1)) + ax.imshow(_as_display_image(data[1])) plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_mod_" + str(model_id) + "_" + model_type + ".png") plt.clf() plt.cla()