Source code for pyEdgeEval.common.multi_label.io

#!/usr/bin/env python3

import os.path as osp

from pyEdgeEval.utils import mkdir_or_exist


[docs]def save_sample_metrics( root_dir: str, sample_metrics, file_name: str = "eval_bdry_img.txt" ): file_path = osp.join(root_dir, file_name) tmp_line = "{name} {thrs:<10.6f} {rec:<10.6f} {prec:<10.6f} {f1:<10.6f}\n" with open(file_path, "w") as f: for res in sample_metrics: f.write( tmp_line.format( name=res["name"], thrs=res["threshold"], rec=res["recall"], prec=res["precision"], f1=res["f1"], ) )
[docs]def save_threshold_metrics( root_dir: str, threshold_metrics, file_name: str = "eval_bdry_thr.txt" ): file_path = osp.join(root_dir, file_name) tmp_line = "{thrs:<10.6f} {rec:<10.6f} {prec:<10.6f} {f1:<10.6f}\n" with open(file_path, "w") as f: for res in threshold_metrics: f.write( tmp_line.format( thrs=res["threshold"], rec=res["recall"], prec=res["precision"], f1=res["f1"], ) )
[docs]def save_overall_metric( root_dir: str, overall_metric, file_name: str = "eval_bdry.txt", ): file_path = osp.join(root_dir, file_name) with open(file_path, "w") as f: f.write( "{:<10.6f} {:<10.6f} {:<10.6f} {:<10.6f} {:<10.6f} {:<10.6f} {:<10.6f} {:<10.6f} {:<10.6f}".format( overall_metric["ODS_threshold"], overall_metric["ODS_recall"], overall_metric["ODS_precision"], overall_metric["ODS_f1"], overall_metric["OIS_recall"], overall_metric["OIS_precision"], overall_metric["OIS_f1"], overall_metric["AUC"], overall_metric["AP"], ) )
[docs]def save_pretty_metrics( root_dir: str, class_table, summary_table, file_name: str = "results.txt", ): file_path = osp.join(root_dir, file_name) data = ( "Per Class Results:\n" + class_table + "\n" + "Summary:\n" + summary_table ) with open(file_path, "w") as f: f.write(data)
[docs]def save_category_results( root: str, category: int, sample_metrics, threshold_metrics, overall_metric, ): """Save per-category results""" cat_name = "class_" + str(category).zfill(3) cat_dir = osp.join(root, cat_name) mkdir_or_exist(cat_dir) if sample_metrics: save_sample_metrics(cat_dir, sample_metrics) if threshold_metrics: save_threshold_metrics(cat_dir, threshold_metrics) if overall_metric: save_overall_metric(cat_dir, overall_metric)