#!/usr/bin/env python3
"""Encoding functions for multi-label edges
- `default`: encodes into binary format
- `RGB`: encodes into binary format that is compatible with image fs
"""
import numpy as np
def default_multilabel_encoding(edges: np.ndarray) -> np.ndarray:
    """Encode multi-label edges into a single-channel bit field.

    Bit ``k`` of every output pixel is set when class ``k`` has an edge at
    that pixel, so the result packs up to 32 classes into one ``uint32`` map.
    Use the RGB encoding instead when at most 24 classes are needed and a
    3-channel 8-bit image is more convenient.

    Args:
        edges: Array of shape ``(num_classes, h, w)``. Any non-zero value is
            treated as "edge present"; bool, integer and float inputs are
            all accepted.

    Returns:
        Array of shape ``(h, w)`` with dtype ``uint32``.

    Raises:
        ValueError: If ``edges`` is not 3-dimensional or holds more than
            32 classes.
    """
    num_classes, h, w = edges.shape
    if num_classes > 32:
        raise ValueError(
            f"uint32 encoding supports at most 32 classes, got {num_classes}"
        )

    cat_edge_map = np.zeros((h, w), dtype=np.uint32)
    for train_id in range(num_classes):
        # Build the bit as a uint32 scalar and OR it in through a boolean
        # mask: this keeps the result uint32 regardless of the input dtype
        # and avoids overflow when multiplying narrow integer inputs.
        bit = np.uint32(1 << train_id)
        cat_edge_map[edges[train_id] != 0] |= bit
    return cat_edge_map
def rgb_multilabel_encoding(edges: np.ndarray) -> np.ndarray:
    """Encode multi-label edges into a 3-channel 8-bit bit field.

    Each channel stores 8 classes, one per bit, for a maximum of 24 classes:

    - classes 0-7   -> channel index 2 (blue)
    - classes 8-15  -> channel index 1 (green)
    - classes 16-23 -> channel index 0 (red)

    The 3-channel ``uint8`` layout lets the encoded edges go through
    augmentations that expect ordinary RGB images.

    Args:
        edges: Array of shape ``(num_classes, h, w)``. Any non-zero value is
            treated as "edge present".

    Returns:
        Array of shape ``(h, w, 3)`` with dtype ``uint8``, channels ordered
        red, green, blue.

    Raises:
        ValueError: If ``edges`` is not 3-dimensional or holds more than
            24 classes.
    """
    num_classes, h, w = edges.shape
    if num_classes > 24:
        raise ValueError(
            f"RGB encoding supports at most 24 classes, got {num_classes}"
        )

    cat_edge_png = np.zeros((h, w, 3), dtype=np.uint8)
    for train_id in range(num_classes):
        # Every group of 8 classes shares a channel; the lowest group lands
        # in the last (blue) channel, hence the reversed index.
        channel = 2 - train_id // 8
        bit = np.uint8(1 << (train_id % 8))
        # Basic slicing yields a view, so the masked OR writes straight
        # into the output array.
        plane = cat_edge_png[:, :, channel]
        plane[edges[train_id] != 0] |= bit
    return cat_edge_png