Source code for saf.process.shannon_mask

# Copyright 2021-2023 VMware, Inc.
# SPDX-License-Identifier: Apache-2.0
"""
Mask data based on a length-relative normalized version of the Shannon Index, used as an indicator of entropy.
"""
from __future__ import annotations

import logging
import math
import string
from typing import Any
from typing import AsyncIterator
from typing import Optional
from typing import Type

from pydantic import Field

from saf.models import CollectedEvent
from saf.models import PipelineRunContext
from saf.models import ProcessConfigBase

log = logging.getLogger(__name__)


class ShannonMaskProcessConfig(ProcessConfigBase):
    """
    Settings model for the Shannon mask processor plugin.

    Controls how string values are split into words, which words count as
    high-entropy, and what text replaces them.
    """

    # Text inserted between ``mask_prefix`` and ``mask_suffix`` in place of a masked word.
    mask_str: str = Field(default="HIGH-ENTROPY")
    # When set, a masked word is replaced by this single character repeated len(word) times
    # instead of the prefix/mask_str/suffix form.
    mask_char: Optional[str] = Field(default=None, min_length=1, max_length=1)
    # Opening marker placed before ``mask_str``.
    mask_prefix: str = Field(default="<:")
    # Closing marker placed after ``mask_str``.
    mask_suffix: str = Field(default=":>")
    # A word is masked only when its normalized Shannon index is strictly greater than this.
    h_threshold: float = Field(default=0.9, ge=0.0, le=1.0)
    # Words shorter than this many characters are never masked.
    length_threshold: int = Field(default=16, gt=1)
    # Single character used to split string values into words.
    # NOTE: the field name's spelling is part of the public config and is kept as-is.
    delimeter: str = Field(default=" ", min_length=1, max_length=1)
    # Characters counted toward the Shannon index (letters, digits, "+", "/", "=").
    alphabet: str = Field(default=string.ascii_letters + string.digits + "+/=", min_length=1)


def get_config_schema() -> Type[ProcessConfigBase]:
    """
    Return the configuration model class used to validate this processor's settings.
    """
    schema_cls: Type[ProcessConfigBase] = ShannonMaskProcessConfig
    return schema_cls


def _calculate_normalized_shannon_index(word: str, alphabet: str) -> float:
    """
    Calculate a length-relative normalized Shannon index of ``word``.

    ``h`` is the Shannon index of the characters of ``word`` that appear in
    ``alphabet``. Each character's probability is its count divided by the full
    word length, so characters outside ``alphabet`` lower the score. ``h_max``
    is the largest index a string of ``len(word)`` characters over ``alphabet``
    can reach. The result is ``h / h_max``.

    Shannon Diversity index: https://www.itl.nist.gov/div898/software/dataplot/refman2/auxillar/shannon.htm

    Returns 0.0 when ``word`` or ``alphabet`` is empty, or when ``h_max`` is 0
    (e.g. a single-character word), since there is nothing to normalize against.
    """
    # pylint: disable=invalid-name
    word_len = len(word)
    alphabet_len = len(alphabet)
    if not word_len or not alphabet_len:
        return 0.0

    # dict.fromkeys de-duplicates characters while keeping first-seen order, so each
    # distinct character is counted once instead of once per occurrence.
    p_dict = {i: word.count(i) / word_len for i in dict.fromkeys(word) if i in alphabet}

    # h is the standard Shannon Index naming convention
    h = sum(-1 * p_i * math.log(p_i) for p_i in p_dict.values())

    # Quotient-Remainder Thm: We have integers d, r such that
    # len(word) = d * len(alphabet) + r where 0 <= r < len(alphabet)
    # The most even spread puts d + 1 copies of r symbols and d copies of the
    # remaining alphabet_len - r symbols, which gives h_max for this length.
    d, r = divmod(word_len, alphabet_len)
    p_r = (d + 1) / word_len
    p_d = d / word_len
    h_max = -(r * p_r * math.log(p_r))
    if d:
        # When the word is shorter than the alphabet (d == 0), the remaining symbols
        # never occur: their 0 * log(0) terms are 0 by convention, and evaluating
        # math.log(0) would raise ValueError.
        h_max -= (alphabet_len - r) * p_d * math.log(p_d)
    if h_max <= 0:
        return 0.0
    return h / h_max
    # pylint: enable=invalid-name


def _shannon_mask(event_piece: str, config: ShannonMaskProcessConfig) -> str:
    """
    Mask high-entropy words in ``event_piece``.

    The string is split on ``config.delimeter``. Each distinct word at least
    ``config.length_threshold`` long whose normalized Shannon index (over
    ``config.alphabet``) exceeds ``config.h_threshold`` is replaced everywhere it
    occurs. On an unexpected error the exception is logged and the string is
    returned with whatever masking was already applied.
    """

    def repl_fn(word: str) -> str:
        """
        The replacement function to be called on each match.

        If a mask_char was provided, use that with matching length.
        Otherwise, use the config.mask_str surrounded by prefix and suffix.
        """
        if config.mask_char:
            return config.mask_char * len(word)
        return f"{config.mask_prefix}{config.mask_str}{config.mask_suffix}"

    orig_str = event_piece
    try:
        # Visit each distinct word once. str.replace below already masks every
        # occurrence, and re-running it for a repeated word could corrupt mask
        # text inserted by the first pass.
        # NOTE(review): a later, different word that happens to occur inside an
        # inserted mask string is still replaced there too.
        for word in dict.fromkeys(event_piece.split(config.delimeter)):
            if len(word) < config.length_threshold:
                continue
            h_norm = _calculate_normalized_shannon_index(word, config.alphabet)
            if h_norm > config.h_threshold:
                # Replace every occurrence, including ones embedded in longer tokens,
                # so a repeated secret is not left visible elsewhere in the value.
                event_piece = event_piece.replace(word, repl_fn(word))
    except Exception:
        # NOTE(review): this logs the unmasked input, which may leak the very data
        # being masked into the logs.
        log.exception("Failed to mask value '%s'", orig_str)
    return event_piece


def _shannon_process(obj: Any, config: ShannonMaskProcessConfig) -> Any:  # noqa: ANN401
    """
    Recursively mask every ``str`` value reachable from ``obj``.

    Strings are masked with ``_shannon_mask``. Lists, tuples, sets and frozensets
    are rebuilt as the same type from their processed items. Dicts have their
    values (not keys) replaced in place and the same dict is returned. Any other
    object is returned unchanged.
    """
    if isinstance(obj, str):
        return _shannon_mask(obj, config)
    if isinstance(obj, (list, tuple, set, frozenset)):
        klass = type(obj)
        return klass(_shannon_process(i, config) for i in obj)
    if isinstance(obj, dict):
        for key, value in obj.items():
            # Assigning to existing keys during iteration does not resize the dict.
            obj[key] = _shannon_process(value, config)
    return obj


async def process(
    *,
    ctx: PipelineRunContext[ShannonMaskProcessConfig],
    event: CollectedEvent,
) -> AsyncIterator[CollectedEvent]:
    """
    Mask high-entropy string values in ``event`` and yield the masked result.

    The event is dumped to a plain structure with ``model_dump()``. String values
    in it are masked recursively according to ``ctx.config``, and a new event is
    built from the result with ``model_validate``.
    """
    config = ctx.config
    # Serializing the whole event to JSON is costly; only do it when the debug
    # record will actually be emitted.
    # NOTE(review): this logs the event before masking, so DEBUG logs may contain
    # the values this processor is meant to hide.
    if log.isEnabledFor(logging.DEBUG):
        log.debug("Processing event in shannon_mask: %s", event.model_dump_json())
    event_dict = event.model_dump()
    processed_event_dict = _shannon_process(event_dict, config)
    yield event.model_validate(processed_event_dict)