Match training colormap for v44 inference images

Automatica_1_v2
Sergey Revyakin 5 hours ago
parent 67e975c49d
commit 8ed0974445

@ -10,18 +10,16 @@ import os
import re import re
def _as_display_image(image): def _as_training_colormap_image(image):
arr = np.asarray(image) arr = np.asarray(image)
if arr.ndim == 3 and arr.shape[0] in {1, 3, 4}: if arr.ndim == 3 and arr.shape[0] in {1, 3, 4}:
arr = np.moveaxis(arr, 0, -1) arr = np.moveaxis(arr, 0, -1)
if np.issubdtype(arr.dtype, np.floating): if arr.ndim == 3:
arr = np.nan_to_num(arr, nan=0.0, posinf=255.0, neginf=0.0) arr = arr[..., :3].astype(np.float32)
if arr.size and (arr.max() > 1.0 or arr.min() < 0.0): arr = 0.299 * arr[..., 0] + 0.587 * arr[..., 1] + 0.114 * arr[..., 2]
return np.clip(arr, 0, 255).astype(np.uint8)
return np.clip(arr, 0.0, 1.0)
return arr return np.nan_to_num(arr.astype(np.float32), nan=0.0, posinf=255.0, neginf=0.0)
def _render_signal_channel(values, figsize=(16, 16), dpi=16, resize=(256, 256)): def _render_signal_channel(values, figsize=(16, 16), dpi=16, resize=(256, 256)):
@ -246,7 +244,7 @@ def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inf
if isinstance(data, (list, tuple)) and len(data) >= 2: if isinstance(data, (list, tuple)) and len(data) >= 2:
fig, ax = plt.subplots() fig, ax = plt.subplots()
ax.imshow(_as_display_image(data[0])) ax.imshow(_as_training_colormap_image(data[0]), cmap="viridis")
plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_real_" + str(model_id) + "_" + model_type + ".png") plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_real_" + str(model_id) + "_" + model_type + ".png")
plt.clf() plt.clf()
plt.cla() plt.cla()
@ -255,7 +253,7 @@ def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inf
gc.collect() gc.collect()
fig, ax = plt.subplots() fig, ax = plt.subplots()
ax.imshow(_as_display_image(data[1])) ax.imshow(_as_training_colormap_image(data[1]), cmap="viridis")
plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_mod_" + str(model_id) + "_" + model_type + ".png") plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_mod_" + str(model_id) + "_" + model_type + ".png")
plt.clf() plt.clf()
plt.cla() plt.cla()

@ -10,18 +10,16 @@ import os
import re import re
def _as_display_image(image): def _as_training_colormap_image(image):
arr = np.asarray(image) arr = np.asarray(image)
if arr.ndim == 3 and arr.shape[0] in {1, 3, 4}: if arr.ndim == 3 and arr.shape[0] in {1, 3, 4}:
arr = np.moveaxis(arr, 0, -1) arr = np.moveaxis(arr, 0, -1)
if np.issubdtype(arr.dtype, np.floating): if arr.ndim == 3:
arr = np.nan_to_num(arr, nan=0.0, posinf=255.0, neginf=0.0) arr = arr[..., :3].astype(np.float32)
if arr.size and (arr.max() > 1.0 or arr.min() < 0.0): arr = 0.299 * arr[..., 0] + 0.587 * arr[..., 1] + 0.114 * arr[..., 2]
return np.clip(arr, 0, 255).astype(np.uint8)
return np.clip(arr, 0.0, 1.0)
return arr return np.nan_to_num(arr.astype(np.float32), nan=0.0, posinf=255.0, neginf=0.0)
def _render_signal_channel(values, figsize=(16, 16), dpi=16, resize=(256, 256)): def _render_signal_channel(values, figsize=(16, 16), dpi=16, resize=(256, 256)):
@ -246,7 +244,7 @@ def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inf
if isinstance(data, (list, tuple)) and len(data) >= 2: if isinstance(data, (list, tuple)) and len(data) >= 2:
fig, ax = plt.subplots() fig, ax = plt.subplots()
ax.imshow(_as_display_image(data[0])) ax.imshow(_as_training_colormap_image(data[0]), cmap="viridis")
plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_real_" + str(model_id) + "_" + model_type + ".png") plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_real_" + str(model_id) + "_" + model_type + ".png")
plt.clf() plt.clf()
plt.cla() plt.cla()
@ -255,7 +253,7 @@ def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inf
gc.collect() gc.collect()
fig, ax = plt.subplots() fig, ax = plt.subplots()
ax.imshow(_as_display_image(data[1])) ax.imshow(_as_training_colormap_image(data[1]), cmap="viridis")
plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_mod_" + str(model_id) + "_" + model_type + ".png") plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_mod_" + str(model_id) + "_" + model_type + ".png")
plt.clf() plt.clf()
plt.cla() plt.cla()

@ -10,18 +10,16 @@ import os
import re import re
def _as_display_image(image): def _as_training_colormap_image(image):
arr = np.asarray(image) arr = np.asarray(image)
if arr.ndim == 3 and arr.shape[0] in {1, 3, 4}: if arr.ndim == 3 and arr.shape[0] in {1, 3, 4}:
arr = np.moveaxis(arr, 0, -1) arr = np.moveaxis(arr, 0, -1)
if np.issubdtype(arr.dtype, np.floating): if arr.ndim == 3:
arr = np.nan_to_num(arr, nan=0.0, posinf=255.0, neginf=0.0) arr = arr[..., :3].astype(np.float32)
if arr.size and (arr.max() > 1.0 or arr.min() < 0.0): arr = 0.299 * arr[..., 0] + 0.587 * arr[..., 1] + 0.114 * arr[..., 2]
return np.clip(arr, 0, 255).astype(np.uint8)
return np.clip(arr, 0.0, 1.0)
return arr return np.nan_to_num(arr.astype(np.float32), nan=0.0, posinf=255.0, neginf=0.0)
def _prune_old_inference_images(src, model_type, model_id, keep_last=200): def _prune_old_inference_images(src, model_type, model_id, keep_last=200):
@ -218,7 +216,7 @@ def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inf
if isinstance(data, (list, tuple)) and len(data) >= 2: if isinstance(data, (list, tuple)) and len(data) >= 2:
fig, ax = plt.subplots() fig, ax = plt.subplots()
ax.imshow(_as_display_image(data[0])) ax.imshow(_as_training_colormap_image(data[0]), cmap="viridis")
plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_real_" + str(model_id) + "_" + model_type + ".png") plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_real_" + str(model_id) + "_" + model_type + ".png")
plt.clf() plt.clf()
plt.cla() plt.cla()
@ -227,7 +225,7 @@ def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inf
gc.collect() gc.collect()
fig, ax = plt.subplots() fig, ax = plt.subplots()
ax.imshow(_as_display_image(data[1])) ax.imshow(_as_training_colormap_image(data[1]), cmap="viridis")
plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_mod_" + str(model_id) + "_" + model_type + ".png") plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_mod_" + str(model_id) + "_" + model_type + ".png")
plt.clf() plt.clf()
plt.cla() plt.cla()

Loading…
Cancel
Save