# Source code for pyEdgeEval.visualization.pr_curve

#!/usr/bin/env python3

"""Try to mimic the original MATLAB PR Curves
Reference: https://github.com/pdollar/edges/blob/master/edgesEvalPlot.m
"""

import os
from collections import namedtuple
from typing import Any, List, Optional

import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import numpy as np
from scipy.interpolate import interp1d


# Record describing one algorithm to be drawn on the PR plot.
AlgorithmInfo = namedtuple(
    typename="AlgorithmInfo",
    field_names=(
        # label used for the algorithm when no explicit names are supplied
        "name",
        # path to the per-threshold results text file
        "threshold_results",
        # path to the overall results text file
        "overall_results",
    ),
)


def _isometric_contour_line_template(
    ax=None,
):
    """Setup Basic Isometric Contour Line Plot"""
    if ax is None:
        ax = plt.gca()

    # plt.box(True)
    ax.set_frame_on(True)
    ax.grid(True)
    ax.axhline(0.5, 0, 1, linewidth=2, color=[0.7, 0.7, 0.7])
    for f in np.arange(0.1, 1, 0.1):
        r = np.arange(f, 1.01, 0.01)
        p = f * r / (2 * r - f)
        ax.plot(r, p, color=[0, 1, 0])
        ax.plot(p, r, color=[0, 1, 0])

    ax.set_xticks(np.linspace(0, 1, 11))
    ax.set_yticks(np.linspace(0, 1, 11))
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_aspect("equal", adjustable="box")
    ax.set(xlim=(0, 1), ylim=(0, 1))

    return ax


def _load_threshold_data(data_path):
    assert os.path.exists(data_path), f"ERR: {data_path} doesn't exist"
    pr = np.loadtxt(data_path)
    pr = pr[pr[:, 1] >= 1e-3]
    return pr


def _load_overall_results(data_path):
    assert os.path.exists(data_path), f"ERR: {data_path} doesn't exist"
    res = np.loadtxt(data_path)
    return res


def _calc_r50(pr):
    _, o = np.unique(pr[:, 2], return_index=True)
    r50 = interp1d(
        pr[o, 2],
        pr[o, 1],
        bounds_error=False,
        fill_value=np.nan,
    )(np.maximum(pr[o[0], 2], 0.5))

    return r50


def plot_pr_curve(
    algs: List[AlgorithmInfo],
    names: Optional[List[str]] = None,
    colors: Any = None,
    save_path: Optional[str] = None,
):
    """Plot precision/recall curves for one or more algorithms.

    For every algorithm the per-threshold table (columns [T, R, P, F]) and
    the overall results are loaded, the recall at 50% precision is
    computed, and the curves are drawn on the iso-F template. Algorithms
    are ordered by their ODS score (column 3 of the overall results), with
    the best one drawn last and listed first in the legend. A one-line
    summary per algorithm is printed to stdout.

    Args:
        algs: non-empty list of ``AlgorithmInfo`` entries.
        names: optional legend labels, one per algorithm; defaults to each
            entry's ``name`` field.
        colors: optional sequence of colors, one per algorithm; defaults to
            evenly spaced samples of the rainbow colormap.
        save_path: when given, the figure is also written to this path.

    Raises:
        AssertionError: if ``algs`` is not a non-empty list, or ``names`` /
            ``colors`` do not have one entry per algorithm.
    """
    assert isinstance(algs, list) and len(algs) > 0
    n = len(algs)

    if names is None:
        names = [a.name for a in algs]
    else:
        assert isinstance(names, list) and len(names) == n
    # arrays so that they can be reordered with the sort index below
    names = np.array(names)

    if colors is None:
        # https://stackoverflow.com/questions/4971269/how-to-pick-a-new-color-for-each-plotted-line-within-a-figure-in-matplotlib
        colors = cm.rainbow(np.linspace(0, 1, n))
    else:
        assert len(colors) == n
    colors = np.array(colors)

    # create basic template
    fig, ax = plt.subplots()
    ax = _isometric_contour_line_template(ax=ax)

    # load results for every algorithm (pr=[T, R, P, F])
    # res columns 0-7 come from the overall results file, column 8 is R50
    hs, res, prs = [None] * n, np.zeros((n, 9), dtype=np.float32), []
    for i, alg in enumerate(algs):
        pr = _load_threshold_data(alg.threshold_results)
        res[i, :8] = _load_overall_results(alg.overall_results)
        res[i, 8] = _calc_r50(pr)
        prs.append(pr)

    # sort algorithms by ODS score
    o = np.argsort(res[:, 3])[::-1]
    # The curves stay in a plain list: the recall filter applied while
    # loading can leave a different number of rows per algorithm, and
    # stacking arrays of different lengths raises ValueError.
    prs = [prs[j] for j in o]
    res, names, colors = res[o, :], names[o], colors[o]

    # plot results for every algorithm (plot best last)
    for i in range(n - 1, -1, -1):
        hs[i] = ax.plot(
            prs[i][:, 1],
            prs[i][:, 2],
            linestyle="-",
            linewidth=3,
            color=colors[i],
        )[0]
        prefix = "ODS={:.3f}, OIS={:.3f}, AP={:.3f}, R50={:.3f}".format(
            *res[i, [3, 6, 7, 8]]
        )
        prefix += " - {}".format(names[i])
        print(prefix)  # should remove

    # add legends
    legend_texts = [
        "[F={:.2f}] {}".format(res[i, 3], names[i]) for i in range(n)
    ]
    ax.legend(hs, legend_texts, loc="lower left")

    if save_path:
        # save through the figure handle so the right figure is written
        # regardless of which figure pyplot considers current
        fig.savefig(save_path)

    # don't really need this, but calling this will close the figure
    plt.show()