|
|
|
|
@ -6,6 +6,58 @@ import torch
|
|
|
|
|
import cv2
|
|
|
|
|
import gc
|
|
|
|
|
import io
|
|
|
|
|
import os
|
|
|
|
|
import re
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _as_display_image(image):
|
|
|
|
|
arr = np.asarray(image)
|
|
|
|
|
if arr.ndim == 3 and arr.shape[0] in {1, 3, 4}:
|
|
|
|
|
arr = np.moveaxis(arr, 0, -1)
|
|
|
|
|
|
|
|
|
|
if np.issubdtype(arr.dtype, np.floating):
|
|
|
|
|
arr = np.nan_to_num(arr, nan=0.0, posinf=255.0, neginf=0.0)
|
|
|
|
|
if arr.size and (arr.max() > 1.0 or arr.min() < 0.0):
|
|
|
|
|
return np.clip(arr, 0, 255).astype(np.uint8)
|
|
|
|
|
return np.clip(arr, 0.0, 1.0)
|
|
|
|
|
|
|
|
|
|
return arr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _prune_old_inference_images(src, model_type, model_id, keep_last=200):
|
|
|
|
|
try:
|
|
|
|
|
keep_last = int(os.getenv("INFERENCE_IMAGE_KEEP_LAST", str(keep_last)))
|
|
|
|
|
except ValueError:
|
|
|
|
|
keep_last = keep_last
|
|
|
|
|
|
|
|
|
|
if keep_last <= 0 or not src or not os.path.isdir(src):
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
pattern = re.compile(
|
|
|
|
|
r"_inference_(\d+)_.*_"
|
|
|
|
|
+ re.escape(str(model_id))
|
|
|
|
|
+ "_"
|
|
|
|
|
+ re.escape(str(model_type))
|
|
|
|
|
+ r"\.png$"
|
|
|
|
|
)
|
|
|
|
|
grouped = {}
|
|
|
|
|
for name in os.listdir(src):
|
|
|
|
|
match = pattern.match(name)
|
|
|
|
|
if match is None:
|
|
|
|
|
continue
|
|
|
|
|
grouped.setdefault(int(match.group(1)), []).append(name)
|
|
|
|
|
|
|
|
|
|
if len(grouped) <= keep_last:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
for old_result_id in sorted(grouped)[: len(grouped) - keep_last]:
|
|
|
|
|
for name in grouped[old_result_id]:
|
|
|
|
|
try:
|
|
|
|
|
os.remove(os.path.join(src, name))
|
|
|
|
|
except FileNotFoundError:
|
|
|
|
|
pass
|
|
|
|
|
except OSError as exc:
|
|
|
|
|
print(f"failed to remove old inference image {name}: {exc}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _render_plot(values, figsize=(16, 16), dpi=16):
|
|
|
|
|
@ -164,9 +216,9 @@ def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inf
|
|
|
|
|
matplotlib.use("Agg")
|
|
|
|
|
plt.ioff()
|
|
|
|
|
|
|
|
|
|
if int(ind_inference) <= 100 and isinstance(data, (list, tuple)) and len(data) >= 2:
|
|
|
|
|
if isinstance(data, (list, tuple)) and len(data) >= 2:
|
|
|
|
|
fig, ax = plt.subplots()
|
|
|
|
|
ax.imshow(np.moveaxis(data[0], 0, -1))
|
|
|
|
|
ax.imshow(_as_display_image(data[0]))
|
|
|
|
|
plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_real_" + str(model_id) + "_" + model_type + ".png")
|
|
|
|
|
plt.clf()
|
|
|
|
|
plt.cla()
|
|
|
|
|
@ -175,7 +227,7 @@ def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inf
|
|
|
|
|
gc.collect()
|
|
|
|
|
|
|
|
|
|
fig, ax = plt.subplots()
|
|
|
|
|
ax.imshow(np.moveaxis(data[1], 0, -1))
|
|
|
|
|
ax.imshow(_as_display_image(data[1]))
|
|
|
|
|
plt.savefig(src + "_inference_" + str(ind_inference) + "_" + prediction + "_mod_" + str(model_id) + "_" + model_type + ".png")
|
|
|
|
|
plt.clf()
|
|
|
|
|
plt.cla()
|
|
|
|
|
@ -183,6 +235,8 @@ def post_func_ensemble(src="", model_type="", prediction="", model_id=0, ind_inf
|
|
|
|
|
cv2.destroyAllWindows()
|
|
|
|
|
gc.collect()
|
|
|
|
|
|
|
|
|
|
_prune_old_inference_images(src, model_type, model_id)
|
|
|
|
|
|
|
|
|
|
plt.clf()
|
|
|
|
|
plt.cla()
|
|
|
|
|
plt.close()
|
|
|
|
|
|