Source code for pyEdgeEval.evaluators.base

#!/usr/bin/env python3

import json
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from copy import deepcopy

import numpy as np
from prettytable import PrettyTable

from pyEdgeEval.common.multi_label import (
    save_overall_metric,
    save_pretty_metrics,
)
from pyEdgeEval.utils import print_log


[docs]class BaseEvaluator(object, metaclass=ABCMeta): # Global variables dataset_root = None pred_root = None split = None # Hidden variables _sample_names = None # don't make this mutable (e.g. []) _logger = "pyEdgeEval" @property def sample_names(self): if self._sample_names is None: print_log("sample_names is None", logger=self._logger) return None else: # avoid changing inplace return deepcopy(self._sample_names)
[docs] @abstractmethod def set_sample_names(self, **kwargs): """Placeholder for setting sample_names""" pass
[docs] @abstractmethod def set_eval_params(self, **kwargs): """Placeholder for setting evaluation parameters""" pass
[docs] @abstractmethod def eval_params(self): """Placeholder for getting evaluation parameters""" pass
[docs] @abstractmethod def evaluate(self, **kwargs): """Placeholder for the main evaluation function""" pass
[docs]class BaseBinaryEvaluator(BaseEvaluator): ...
[docs]class BaseMultilabelEvaluator(BaseEvaluator): CLASSES = None @abstractmethod def _before_evaluation(self): pass
[docs] @abstractmethod def evaluate_category(self, **kwargs): """Placeholder for the category evaluation function""" pass
[docs] def evaluate( self, categories, thresholds, nproc, save_dir, print_metrics=( "ODS_threshold", "ODS_recall", "ODS_precision", "ODS_f1", "AUC", ), ): self._before_evaluation() # print evaluation params pretty_eval_params = json.dumps( self.eval_params, sort_keys=False, indent=4 ) print_log(pretty_eval_params, logger=self._logger) # check number of categories (indexed from 1) if isinstance(categories, int): categories = [categories] if categories is None: categories = list(range(1, len(self.CLASSES) + 1)) assert isinstance( categories, list ), f"ERR: `categories` should be a list, but got {type(categories)}" assert len(categories) > 0, "ERR: 0 categories" ret_metrics = OrderedDict() for category in categories: # do a single category evaluation overall_metric = self.evaluate_category( category=category, thresholds=thresholds, nproc=nproc, save_dir=save_dir, ) for k, v in overall_metric.items(): if k in ret_metrics.keys(): ret_metrics[k].append(v) else: ret_metrics[k] = [v] if save_dir: # output basic metrics (unformatted) ret_metrics_out = OrderedDict( { ret_metric: np.round(np.nanmean(ret_metric_value), 6) for ret_metric, ret_metric_value in ret_metrics.items() } ) save_overall_metric( root_dir=save_dir, overall_metric=ret_metrics_out, file_name="overall_eval_bdry.txt", ) # isolate metrics that are not percentage ret_thresholds = ret_metrics.pop("ODS_threshold", None) # summary table ret_metrics_summary = OrderedDict( { ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 3) for ret_metric, ret_metric_value in ret_metrics.items() } ) # each class table remove_metrics = ["OIS_recall", "OIS_precision", "OIS_f1", "AUC", "AP"] for m in remove_metrics: ret_metrics.pop(m, None) class_names = [self.CLASSES[i - 1] for i in categories] ret_metrics_class = OrderedDict( { ret_metric: np.round(np.array(ret_metric_value) * 100, 3) for ret_metric, ret_metric_value in ret_metrics.items() } ) ret_metrics_class.update( {"ODS_threshold": np.round(np.array(ret_thresholds), 2)} ) ret_metrics_class.update({"Class": class_names}) ret_metrics_class.move_to_end("Class", last=False) # for logger class_table_data = PrettyTable() for key, val in ret_metrics_class.items(): if key in print_metrics or key == "Class": class_table_data.add_column(key, val) summary_table_data = PrettyTable() for key, val in ret_metrics_summary.items(): if key in print_metrics: summary_table_data.add_column(key, [val]) print_log("per class results:", logger=self._logger) print_log("\n" + class_table_data.get_string(), logger=self._logger) print_log("Summary:", logger=self._logger) print_log("\n" + summary_table_data.get_string(), logger=self._logger) # save total metrics if save_dir: save_pretty_metrics( root_dir=save_dir, class_table=class_table_data.get_string(), summary_table=summary_table_data.get_string(), )