#!/usr/bin/env python3
import numpy as np
def mask2onehot(mask, labels):
    """Convert a label mask into a stack of per-label binary planes.

    Plane ``k`` of the result is 1 wherever ``mask`` equals ``labels[k]`` and
    0 elsewhere, so an (H, W) mask becomes a (K, H, W) array whose *first*
    axis indexes the labels.

    Args:
        mask: array of label values, typically shaped (H, W).
        labels: non-empty sequence of label values; one output plane is
            produced per entry, in the given order.

    Returns:
        ``np.uint8`` array of shape ``(len(labels),) + mask.shape``.

    Raises:
        AssertionError: if ``labels`` is empty.
    """
    assert len(labels) > 0, "`labels` should be a list with at least 1 element"
    mask = np.asarray(mask)
    # NOTE: the number of labels is deliberately not compared against
    # mask.shape[0]; that is the mask's height, which has no relation to how
    # many labels may be requested.
    planes = [mask == label for label in labels]
    return np.array(planes).astype(np.uint8)
def edge_multilabel2binary(edges: np.ndarray) -> np.ndarray:
    """Collapse a stack of edge maps into a single binary edge map.

    The input is summed along its first axis; every position whose sum is
    strictly positive becomes 1, all others become 0.

    Args:
        edges: array whose first axis stacks the individual edge maps
            (e.g. one map per class).

    Returns:
        ``np.uint8`` array with the first axis removed, containing only
        0 and 1.
    """
    return (np.sum(edges, axis=0) > 0).astype(np.uint8)
def edge_onehot2multilabel(edges: np.ndarray) -> np.ndarray:
    """Pack per-label edge channels into a single bit-encoded edge map.

    Channel ``k`` of ``edges`` is weighted by ``2**k`` and the weighted
    channels are summed, so for 0/1 channels bit ``k`` of each output pixel
    tells whether label ``k`` has an edge there (multi-label information is
    preserved).

    Args:
        edges: array of shape (K, H, W), one channel per label.

    Returns:
        (H, W) array of dtype ``np.uint32`` when ``K <= 32``, or
        ``np.uint64`` when ``32 < K <= 64``.

    Raises:
        ValueError: if ``edges`` is not 3-dimensional or has more than
            64 channels.
    """
    if edges.ndim != 3:
        raise ValueError(
            f"ERR: expected edges of shape (K, H, W) but got {edges.shape}"
        )
    num_labels, h, w = edges.shape

    if num_labels <= 32:
        code_dtype = np.uint32
    elif num_labels <= 64:
        code_dtype = np.uint64
    else:
        raise ValueError(
            f"ERR: cannot bit-encode {num_labels} labels into a 64-bit integer"
        )

    edge_map = np.zeros((h, w), dtype=code_dtype)
    for bit, channel in enumerate(edges):
        # Cast the channel before weighting: multiplying a narrow-dtype
        # channel (e.g. uint8) by a plain Python int would overflow or change
        # the result dtype depending on the NumPy version.
        edge_map += channel.astype(code_dtype) * code_dtype(1 << bit)
    return edge_map
def mask_label2trainId(mask: np.ndarray, label2trainId: dict) -> np.ndarray:
    """Python version of `labelid2trainid` function for segmentation data

    Every pixel whose value is a key of ``label2trainId`` is replaced by the
    corresponding value; every other pixel is set to 255.

    Args:
        mask: single channel image containing segmentation label, shaped
            (H, W) or (H, W, 1)
        label2trainId: mapping from label id to train id. The output is an
            8-bit array, so train ids are stored as ``np.uint8``.

    Returns:
        np.ndarray: (H, W) array of dtype ``np.uint8``

    Raises:
        AssertionError: if a 3-dimensional mask has more than one channel
        ValueError: if the mask is neither 2- nor 3-dimensional
    """
    if mask.ndim == 3:
        c = mask.shape[2]
        assert c == 1, f"ERR: input label has {c} channels which should be 1"
        # Drop the channel axis so the boolean index built below has the same
        # (H, W) shape as the output array.
        mask = mask[..., 0]
    elif mask.ndim != 2:
        raise ValueError(
            f"ERR: mask should be (H, W) or (H, W, 1) but got shape {mask.shape}"
        )

    # 1. create an array populated with 255 (background pixel)
    trainId_mask = np.full(mask.shape, 255, dtype=np.uint8)  # 8-bit array

    # 2. map all pixels from `label` to `trainId`; the comparison always reads
    #    the untouched input mask, so one mapping never feeds into another
    for labelId, trainId in label2trainId.items():
        trainId_mask[mask == labelId] = trainId
    return trainId_mask
def edge_label2trainId(edge: np.ndarray, label2trainId: dict) -> np.ndarray:
    """Reorder per-label edge channels into per-trainId edge channels.

    For every ``labelId -> trainId`` entry, channel ``labelId`` of ``edge``
    is copied into channel ``trainId`` of the output. The output has one
    channel per entry of ``label2trainId``; channels that are never written
    stay zero, and if two entries share a train id the later one overwrites
    the earlier one.

    Args:
        edge: array of shape (C, H, W), indexed by label id on the first axis.
        label2trainId: mapping from label id to train id; each train id must
            lie in ``[0, len(label2trainId))``.

    Returns:
        ``np.uint8`` array of shape ``(len(label2trainId), H, W)``.

    Raises:
        AssertionError: if ``edge`` is not 3-dimensional.
        IndexError: if a train id falls outside ``[0, len(label2trainId))``.
    """
    assert (
        len(edge.shape) == 3
    ), f"ERR: should be 3 channel input but got {edge.shape}"
    _, h, w = edge.shape

    num_train_ids = len(label2trainId)
    edges_trainIds = np.zeros((num_train_ids, h, w), dtype=np.uint8)
    for labelId, trainId in label2trainId.items():
        # Reject out-of-range ids explicitly: a negative train id would
        # otherwise wrap around and silently overwrite a channel counted
        # from the end.
        if not 0 <= trainId < num_train_ids:
            raise IndexError(
                f"ERR: trainId {trainId} (for labelId {labelId}) is outside "
                f"the valid range [0, {num_train_ids})"
            )
        edges_trainIds[trainId] = edge[labelId, ...]
    return edges_trainIds