Source code for pyEdgeEval.evaluators.sbd

#!/usr/bin/env python3

import os.path as osp

from pyEdgeEval.common.multi_label import (
    calculate_metrics,
    save_category_results,
)
from pyEdgeEval.datasets import sbd_eval_single
from pyEdgeEval.utils import print_log

from .base import BaseMultilabelEvaluator


[docs]class SBDEvaluator(BaseMultilabelEvaluator): """SBD evaluator""" # Dataset specific attributes CLASSES = ( "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor", ) CLS_DIR = "cls" INST_DIR = "inst" GT_SUFFIX = ".mat" PRED_SUFFIX = ".bmp" def __init__( self, dataset_root: str, pred_root: str, split: str = "val", **kwargs, ): self.dataset_root = dataset_root self.pred_root = pred_root assert split in ("val", "test") self.split = split self.CLS_root = osp.join( self.dataset_root, self.CLS_DIR, ) self.INST_root = osp.join( self.dataset_root, self.INST_DIR, ) # we can try to load some sample names try: self.set_sample_names() except Exception: print_log( "Tried to set sample_names, but couldn't", logger=self._logger, )
[docs] def set_sample_names(self, sample_names=None, split_file=None): """priortizes `sample_names` more than `split_file`""" if sample_names is None: # load sample_names from split file if split_file is None: split_file = osp.join(self.dataset_root, f"{self.split}.txt") assert osp.exists(split_file), f"ERR: {split_file} does not exist!" print_log(f"Loading samples from {split_file}", logger=self._logger) with open(split_file, "r") as f: sample_names = f.read().splitlines() assert isinstance( sample_names, list ), f"ERR: sample_names should be a list but got {type(sample_names)}" assert len(sample_names) > 0, "ERR: sample_names is empty" self._sample_names = sample_names
[docs] def set_eval_params( self, eval_mode=None, scale: float = 1.0, apply_thinning: bool = False, apply_nms: bool = False, instance_sensitive: bool = True, max_dist: float = 0.02, skip_if_nonexistent: bool = False, kill_internal: bool = False, **kwargs, ) -> None: assert 0 < scale <= 1, f"ERR: scale ({scale}) is not valid" self.scale = scale self.apply_thinning = apply_thinning self.apply_nms = apply_nms if eval_mode == "pre-seal": print_log("Using Pre-SEAL params", logger=self._logger) self.max_dist = 0.02 self.kill_internal = True self.skip_if_nonexistent = True self.instance_sensitive = False elif eval_mode == "post-seal": print_log("Using Post-SEAL params", logger=self._logger) self.max_dist = 0.02 self.kill_internal = False self.skip_if_nonexistent = False self.instance_sensitive = True elif eval_mode == "high-quality": print_log( "Using params for high-quality annotations", logger=self._logger ) self.max_dist = 0.0075 self.kill_internal = False self.skip_if_nonexistent = False self.instance_sensitive = True else: print_log("Using custom params", logger=self._logger) self.max_dist = max_dist self.kill_internal = kill_internal self.skip_if_nonexistent = skip_if_nonexistent self.instance_sensitive = instance_sensitive if self.kill_internal and self.instance_sensitive: print_log( "kill_internal and instance_sensitive are both True which will conflict with each either", logger=self._logger, )
@property def eval_params(self): return dict( scale=self.scale, apply_thinning=self.apply_thinning, apply_nms=self.apply_nms, max_dist=self.max_dist, kill_internal=self.kill_internal, skip_if_nonexistent=self.skip_if_nonexistent, ) def _before_evaluation(self): assert osp.exists( self.dataset_root ), f"ERR: {self.dataset_root} does not exist" assert osp.exists(self.CLS_root), f"ERR: {self.CLS_root} does not exist" assert osp.exists( self.INST_root ), f"ERR: {self.INST_root} does not exist" assert osp.exists( self.pred_root ), f"ERR: {self.pred_root} does not exist"
[docs] def evaluate_category( self, category, thresholds, nproc, save_dir, ): self._before_evaluation() assert ( 0 < category < len(self.CLASSES) + 1 ), f"ERR: category={category} is not in range ({len(self.CLASSES) + 1})" # populate data (samples) data = [] for sample_name in self.sample_names: # FIXME: naming scheme differs sometimes category_dir = f"class_{str(category).zfill(3)}" # GT file paths cls_path = osp.join(self.CLS_root, f"{sample_name}{self.GT_SUFFIX}") if self.instance_sensitive: inst_path = osp.join( self.INST_root, f"{sample_name}{self.GT_SUFFIX}" ) assert osp.exists(inst_path), f"ERR: {inst_path} is not valid" else: inst_path = None assert osp.exists(cls_path), f"ERR: {cls_path} is not valid" # prediction file path # sample_name = sample_name.split("/")[1] # assert city/img pred_path = osp.join( self.pred_root, category_dir, f"{sample_name}{self.PRED_SUFFIX}", ) data.append( dict( name=sample_name, category=category, thresholds=thresholds, # file paths cls_path=cls_path, inst_path=inst_path, pred_path=pred_path, **self.eval_params, ) ) assert len(data) > 0, "ERR: no evaluation data" # evaluate (sample_metrics, threshold_metrics, overall_metric) = calculate_metrics( eval_single=sbd_eval_single, thresholds=thresholds, samples=data, nproc=nproc, ) # save metrics if save_dir: print_log("Saving category results", logger=self._logger) save_category_results( root=save_dir, category=category, sample_metrics=sample_metrics, threshold_metrics=threshold_metrics, overall_metric=overall_metric, ) return overall_metric