# Source code for pyEdgeEval.common.multi_label.edge_encoding

#!/usr/bin/env python3

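"""Encoding functions for multi-label edges.

- `default`: packs the per-class binary edge maps into the bits of a single
  uint32 map (at most 32 classes).
- `RGB`: packs the per-class binary edge maps into the bits of a 3-channel
  uint8 image (at most 24 classes), which can be saved as an ordinary image
  file and passed through 3-channel augmentations.
"""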

import numpy as np


def default_multilabel_encoding(edges: np.ndarray) -> np.ndarray:
    """Encode multi-label edges into a single uint32 bit-packed map.

    Bit ``k`` of each output pixel is set iff class ``k`` has an edge at that
    pixel. uint32 is used as the base type because PIL can save uint32
    images, so at most 32 classes can be encoded. The RGB encoding is more
    compact and should be preferred unless more than 24 classes are needed.

    Args:
        edges: Array of shape ``(num_classes, h, w)``. Any nonzero value is
            treated as an edge, so 0/1, boolean and 0/255 maps all work.

    Returns:
        ``(h, w)`` array of dtype ``np.uint32``.

    Raises:
        ValueError: If ``edges`` is not 3-D or has more than 32 classes.
    """
    edges = np.asarray(edges)
    num_classes, h, w = edges.shape
    if num_classes > 32:
        raise ValueError(
            f"uint32 encoding supports at most 32 classes, got {num_classes}"
        )

    cat_edge_map = np.zeros((h, w), dtype=np.uint32)
    for train_id in range(num_classes):
        # Binarize first and keep the arithmetic in uint32. Multiplying by
        # the Python int 2**k promotes bool/int64 inputs to int64 and
        # overflows uint8 inputs under NumPy 2 promotion rules.
        mask = (edges[train_id] != 0).astype(np.uint32)
        cat_edge_map |= mask << np.uint32(train_id)
    return cat_edge_map
def rgb_multilabel_encoding(edges: np.ndarray) -> np.ndarray:
    """Encode multi-label edges into a bit-packed 3-channel uint8 image.

    Each channel is 8-bit, so the image holds 24 bits, i.e. at most 24
    classes. Classes 0-7 go to the B channel (index 2), classes 8-15 to G
    (index 1) and classes 16-23 to R (index 0). Class ``k`` sets bit
    ``k % 8`` of its channel. This format is useful for training data whose
    edges must be transformed, because it works with 3-channel augmentations.

    Args:
        edges: Array of shape ``(num_classes, h, w)``. Any nonzero value is
            treated as an edge, so 0/1, boolean and 0/255 maps all work.

    Returns:
        ``(h, w, 3)`` array of dtype ``np.uint8`` in R, G, B channel order.

    Raises:
        ValueError: If ``edges`` is not 3-D or has more than 24 classes.
    """
    edges = np.asarray(edges)
    num_classes, h, w = edges.shape
    if num_classes > 24:
        raise ValueError(
            f"RGB encoding supports at most 24 classes, got {num_classes}"
        )

    cat_edge_png = np.zeros((h, w, 3), dtype=np.uint8)
    for train_id in range(num_classes):
        # byte 0 -> B (channel 2), byte 1 -> G (channel 1), byte 2 -> R (0)
        byte_idx, bit = divmod(train_id, 8)
        channel = 2 - byte_idx
        # Binarize so a non-0/1 edge value cannot carry into other classes.
        mask = (edges[train_id] != 0).astype(np.uint8)
        cat_edge_png[:, :, channel] |= mask << np.uint8(bit)
    return cat_edge_png