Source code for pyEdgeEval.helpers.evaluate_sbd

#!/usr/bin/env python3

import argparse
import os.path as osp
import time

from pyEdgeEval.evaluators.sbd import SBDEvaluator
from pyEdgeEval.utils import get_root_logger, mkdir_or_exist

__all__ = ["evaluate_sbd"]


def parse_args():
    parser = argparse.ArgumentParser(description="Evaluate SBD output")
    parser.add_argument(
        "sbd_path",
        type=str,
        help="the root path of the SBD dataset",
    )
    parser.add_argument(
        "pred_path",
        type=str,
        help="the root path of the predictions",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        help="the root path of where the results are populated",
    )
    parser.add_argument(
        "--categories",
        type=str,
        default="[15]",
        help="the category number to evaluate; can be multiple values'[15]'",
    )
    parser.add_argument(
        "--thresholds",
        type=str,
        default="99",
        help="the number of thresholds (could be a list of floats); use 99 for eval",
    )
    parser.add_argument(
        "--raw",
        action="store_true",
        help="option to remove the thinning process (i.e. uses raw predition)",
    )
    parser.add_argument(
        "--apply-nms",
        action="store_true",
        help="applies NMS before evaluation",
    )
    parser.add_argument(
        "--kill-internal",
        action="store_true",
        help="kill internal contour",
    )
    parser.add_argument(
        "--nproc",
        type=int,
        default=4,
        help="the number of parallel threads",
    )

    return parser.parse_args()


def evaluate(
    sbd_path: str,
    pred_path: str,
    output_path: str,
    categories: str,
    apply_thinning: bool,
    apply_nms: bool,
    thresholds: str,
    nproc: int,
):
    """Evaluate SBD"""

    if categories is None:
        print("use all categories")
        categories = list(range(1, len(SBDEvaluator.CLASSES) + 1))
    else:
        # string evaluation for categories
        categories = categories.strip()
        try:
            categories = [int(categories)]
        except ValueError:
            try:
                if categories.startswith("[") and categories.endswith("]"):
                    categories = categories[1:-1]
                    categories = [
                        int(cat.strip()) for cat in categories.split(",")
                    ]
                else:
                    print(
                        "Bad categories format; should be a python list of floats (`[a, b, c]`)"
                    )
                    return
            except ValueError:
                print(
                    "Bad categories format; should be a python list of ints (`[a, b, c]`)"
                )
                return

    for cat in categories:
        assert (
            0 < cat < 21
        ), f"category needs to be between 1 ~ 19, but got {cat}"

    # string evaluation for thresholds
    thresholds = thresholds.strip()
    try:
        n_thresholds = int(thresholds)
        thresholds = n_thresholds
    except ValueError:
        try:
            if thresholds.startswith("[") and thresholds.endswith("]"):
                thresholds = thresholds[1:-1]
                thresholds = [float(t.strip()) for t in thresholds.split(",")]
            else:
                print(
                    "Bad threshold format; should be a python list of floats (`[a, b, c]`)"
                )
                return
        except ValueError:
            print(
                "Bad threshold format; should be a python list of ints (`[a, b, c]`)"
            )
            return

    if output_path is None:
        timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
        output_path = osp.join(
            osp.normpath(pred_path), f"edge_results_{timestamp}"
        )

    mkdir_or_exist(output_path)

    timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
    log_file = osp.join(output_path, f"{timestamp}.log")
    logger = get_root_logger(log_file=log_file, log_level="INFO")
    logger.info("Running SBD Evaluation")
    logger.info(f"categories: \t{categories}")
    logger.info(f"thresholds: \t{thresholds}")
    logger.info(f"thin:       \t{apply_thinning}")
    logger.info(f"nms:        \t{apply_nms}")
    print("\n\n")

    # initialize evaluator
    evaluator = SBDEvaluator(
        dataset_root=sbd_path,
        pred_root=pred_path,
    )
    if evaluator.sample_names is None:
        # load custom sample names
        # SAMPLE_NAMES = ["2008_000051", "2008_000195"]
        evaluator.set_sample_names()

    # set parameters
    eval_mode = "pre-seal"  # FIXME: hard-coded for now
    evaluator.set_eval_params(
        eval_mode=eval_mode,
        scale=1.0,
        apply_thinning=apply_thinning,
        apply_nms=apply_nms,
        instance_sensitive=False,
    )

    # evaluate
    evaluator.evaluate(
        categories=categories,
        thresholds=thresholds,
        nproc=nproc,
        save_dir=output_path,
    )


[docs]def evaluate_sbd(): args = parse_args() # by default, we apply thinning apply_thinning = not args.raw evaluate( sbd_path=args.sbd_path, pred_path=args.pred_path, output_path=args.output_path, categories=args.categories, apply_thinning=apply_thinning, apply_nms=args.apply_nms, thresholds=args.thresholds, nproc=args.nproc, )