From e34f030f9b250dc73d2a86c94a1aad26eae09fdb Mon Sep 17 00:00:00 2001 From: Igor Jankowski Date: Fri, 6 Mar 2026 19:04:08 +0100 Subject: [PATCH] Refactor reward system, fix action preconditions, and extract metrics safely --- marl_cyborg/actions/blue/mitigation.py | 13 +- marl_cyborg/actions/red/exploits.py | 9 +- .../actions/red/privilege_escalation.py | 47 +++-- marl_cyborg/environment/parallel_env.py | 68 +++++++- marl_cyborg/scenarios/ransomware.py | 162 ++++++++++++++---- 5 files changed, 227 insertions(+), 72 deletions(-) diff --git a/marl_cyborg/actions/blue/mitigation.py b/marl_cyborg/actions/blue/mitigation.py index ba5acfd..bc39cd5 100644 --- a/marl_cyborg/actions/blue/mitigation.py +++ b/marl_cyborg/actions/blue/mitigation.py @@ -82,8 +82,12 @@ def execute(self, global_state) -> ActionEffect: """ return ActionEffect( success=True, - state_deltas={f'hosts/{self.target_ip}/status': 'online'}, - observation_data={'alert': 'Host restored.'}, + state_deltas={ + f'hosts/{self.target_ip}/status': 'online', + f'hosts/{self.target_ip}/privilege': 'None', + f'hosts/{self.target_ip}/compromised_by': 'None', + }, + observation_data={'alert': 'Host restored and cleaned.'}, ) @@ -125,7 +129,10 @@ def execute(self, global_state) -> ActionEffect: """ return ActionEffect( success=True, - state_deltas={f'hosts/{self.target_ip}/privilege': 'None'}, + state_deltas={ + f'hosts/{self.target_ip}/privilege': 'None', + f'hosts/{self.target_ip}/compromised_by': 'None', + }, observation_data={'alert': 'Unauthorized access removed.'}, ) diff --git a/marl_cyborg/actions/red/exploits.py b/marl_cyborg/actions/red/exploits.py index 56bacc6..bb7bd9a 100644 --- a/marl_cyborg/actions/red/exploits.py +++ b/marl_cyborg/actions/red/exploits.py @@ -242,14 +242,7 @@ def __init__(self, agent_id: str, target_ip: str): super().__init__(agent_id, target_ip=target_ip) def validate(self, global_state) -> bool: - """Ascertains valid routing bounds to the simulated web interface. - - Args: - global_state (GlobalNetworkState): Simulator reference parameter. - - Returns: - bool: Availability check result. - """ + """Requires valid routing to the web interface.""" return global_state.can_route_to(self.target_ip) def execute(self, global_state) -> ActionEffect: diff --git a/marl_cyborg/actions/red/privilege_escalation.py b/marl_cyborg/actions/red/privilege_escalation.py index bf43839..1b5906b 100644 --- a/marl_cyborg/actions/red/privilege_escalation.py +++ b/marl_cyborg/actions/red/privilege_escalation.py @@ -18,15 +18,16 @@ def __init__(self, agent_id: str, target_ip: str): super().__init__(agent_id, target_ip=target_ip) def validate(self, global_state) -> bool: - """Validates the pre-conditions for privilege escalation natively. + """Validates the pre-conditions for privilege escalation. - Args: - global_state (GlobalNetworkState): Snapshot of the network environment. - - Returns: - bool: True if escalation is physically feasible. + Requires: + - Valid routing to the target host. + - Prior User-level access on the target (must exploit first). """ - return True + host = global_state.all_hosts.get(self.target_ip) + if not host or host.privilege != 'User': + return False + return global_state.can_route_to(self.target_ip) def execute(self, global_state) -> ActionEffect: """Applies the mathematical delta to elevate the agent's privilege @@ -65,15 +66,13 @@ def __init__(self, agent_id: str, target_ip: str): super().__init__(agent_id, target_ip=target_ip) def validate(self, global_state) -> bool: - """Validates target compatibility (e.g., Windows OS assumption). - - Args: - global_state: Network state. - - Returns: - bool: True assuming the agent has obtained baseline 'User' access. - """ - return True + """Validates target compatibility: requires User access + Windows OS.""" + host = global_state.all_hosts.get(self.target_ip) + if not host or host.privilege != 'User': + return False + if 'Windows' not in host.os: + return False + return global_state.can_route_to(self.target_ip) def execute(self, global_state) -> ActionEffect: """Processes the DCOM impersonation attack delta. Fails if target OS is @@ -120,15 +119,13 @@ def __init__(self, agent_id: str, target_ip: str): super().__init__(agent_id, target_ip=target_ip) def validate(self, global_state) -> bool: - """Validates routing or baseline access requirements. - - Args: - global_state: Network state. - - Returns: - bool: Execution clearance boolean. - """ - return True + """Validates: requires User access + Linux OS + V4L2 vulnerability.""" + host = global_state.all_hosts.get(self.target_ip) + if not host or host.privilege != 'User': + return False + if 'Linux' not in host.os: + return False + return global_state.can_route_to(self.target_ip) def execute(self, global_state) -> ActionEffect: """Resolves the exploit outcome altering the target's privilege table. diff --git a/marl_cyborg/environment/parallel_env.py b/marl_cyborg/environment/parallel_env.py index 978c691..00c8363 100644 --- a/marl_cyborg/environment/parallel_env.py +++ b/marl_cyborg/environment/parallel_env.py @@ -236,7 +236,10 @@ def step( if not terminate[agent] and not truncate[agent] ] - return observations, rewards, terminate, truncate, {a: {} for a in self.agents} + # ── Build info dicts with security metrics for callbacks ── + infos = self._extract_agent_infos(observations, resolved_effects) + + return observations, rewards, terminate, truncate, infos def render(self): """Standard PettingZoo GUI logging render hook.""" @@ -409,3 +412,66 @@ def _calculate_reward( ) -> float: """Delegates reward logic directly to the localized Scenario module.""" return self.scenario.calculate_reward(agent_id, state, effect) + + def _extract_agent_infos(self, observations: dict, resolved_effects: dict) -> dict: + """Extracts security metrics for TensorBoard and CSV logging callbacks. + + Args: + observations: Dictionary of agent observations for this step. + resolved_effects: Dictionary of resolved action effects. + + Returns: + Dictionary mapping agent_id to an info dictionary with security metrics. + """ + infos = {} + for agent in list(observations.keys()): + agent_effect = resolved_effects.get(agent) + info: dict = {} + + # Count security-relevant events from this step + false_positives = 0 + successful_exploits = 0 + hosts_isolated = 0 + services_restored = 0 + + if agent_effect and agent_effect.success: + for delta_key, delta_val in agent_effect.state_deltas.items(): + if 'status' in delta_key and delta_val == 'isolated': + hosts_isolated += 1 + # Check if the isolated host was actually compromised + parts = delta_key.split('/') + if len(parts) >= 2: + ip = parts[1] + host = self.global_state.all_hosts.get(ip) + if host and host.compromised_by == 'None': + false_positives += 1 # Isolated a clean host + elif 'privilege' in delta_key and delta_val in ('User', 'Root'): + successful_exploits += 1 + elif 'status' in delta_key and delta_val == 'online': + services_restored += 1 + + info['false_positives'] = float(false_positives) + info['successful_exploits'] = float(successful_exploits) + info['hosts_isolated'] = float(hosts_isolated) + info['services_restored'] = float(services_restored) + + # Extra context for analysis + info['agent_energy'] = float(self.global_state.agent_energy.get(agent, 0)) + info['compromised_hosts'] = float( + sum( + 1 + for h in self.global_state.all_hosts.values() + if h.compromised_by != 'None' + ) + ) + info['isolated_hosts'] = float( + sum( + 1 + for h in self.global_state.all_hosts.values() + if h.status == 'isolated' + ) + ) + + infos[agent] = info + + return infos diff --git a/marl_cyborg/scenarios/ransomware.py b/marl_cyborg/scenarios/ransomware.py index fa7fe38..49b907a 100644 --- a/marl_cyborg/scenarios/ransomware.py +++ b/marl_cyborg/scenarios/ransomware.py @@ -10,9 +10,12 @@ class RansomwareScenario(BaseScenario): """ Scenario: Red Team aggressively deploys ransomware. - - Red is highly rewarded for system_integrity compromise (Impact actions). - - Red ignores stealth (virtually no noise penalty). - - Blue loses massive points for impacted systems and must isolate quickly. + Reward Design Principles: + ───────────────────────── + • Action-level rewards are given ONCE on the step the action succeeds. + • NO per-step state compounding (avoids runaway accumulation). + • Episode-end bonus/penalty based on final network state. + • Red and Blue reward magnitudes are balanced (~±5 per action). """ def __init__(self, agents): @@ -26,46 +29,135 @@ def calculate_reward( ) -> float: reward = 0.0 - if effect: - if getattr(effect, 'cost', 0) > 0: - reward -= effect.cost * 0.1 # Micro-penalty for expending energy + # Small energy cost penalty for any action + if effect and getattr(effect, 'cost', 0) > 0: + reward -= effect.cost * 0.05 - if 'red' in agent_id.lower() and effect.observation_data: - # Reward successful intelligence sharing inherently - if 'shared' in effect.observation_data: + if 'red' in agent_id.lower(): + reward += self._red_reward(agent_id, global_state, effect) + elif 'blue' in agent_id.lower(): + reward += self._blue_reward(agent_id, global_state, effect) + + return reward + + # ── Red Team Rewards ────────────────────────────────────── + + def _red_reward( + self, + agent_id: str, + global_state: 'GlobalNetworkState', + effect: 'ActionEffect' = None, + ) -> float: + reward = 0.0 + + if not effect: + return reward + + # ── ONE-TIME action bonuses (only on success) ───────── + if effect.success and effect.state_deltas: + for delta_key, delta_val in effect.state_deltas.items(): + # Initial compromise (None → User) + if 'privilege' in delta_key and delta_val == 'User': + reward += 3.0 + + # Privilege escalation (User → Root) + elif 'privilege' in delta_key and delta_val == 'Root': + reward += 5.0 + + # Ransomware impact (system_integrity → compromised) + elif 'system_integrity' in delta_key and delta_val == 'compromised': + reward += 10.0 + + # Setting compromised_by (confirms exploitation) + elif 'compromised_by' in delta_key and delta_val != 'None': reward += 2.0 - # Penalize falling for Decoys or EDR telemetry traps - if 'Failed against Decoy' in str(effect.observation_data.values()): - reward -= 5.0 - elif 'kernel panic' in str(effect.observation_data.values()): - reward -= 10.0 - red_impact_count = sum( + # ── Observation-based rewards ───────────────────────── + if effect.observation_data: + obs = effect.observation_data + + # Reconnaissance discoveries + if 'discovered_hosts' in obs: + discovered = obs['discovered_hosts'] + count = len(discovered) if isinstance(discovered, (list, set)) else 1 + reward += count * 0.5 + + if 'scan_results' in obs: + reward += 0.3 + + # Intelligence sharing + if 'shared' in obs: + reward += 1.0 + + # Penalties for failures + if 'Failed against Decoy' in str(obs.values()): + reward -= 3.0 + elif 'kernel panic' in str(obs.values()): + reward -= 5.0 + + # ── Failed action penalty ───────────────────────────── + if not effect.success: + reward -= 0.1 # Small penalty for wasted turn + + return reward + + # ── Blue Team Rewards ───────────────────────────────────── + + def _blue_reward( + self, + agent_id: str, + global_state: 'GlobalNetworkState', + effect: 'ActionEffect' = None, + ) -> float: + reward = 0.0 + + # ── ONE-TIME action bonuses ─────────────────────────── + if effect and effect.success and effect.state_deltas: + for delta_key, delta_val in effect.state_deltas.items(): + # Successful isolation + if 'status' in delta_key and delta_val == 'isolated': + ip = delta_key.split('/')[1] if '/' in delta_key else None + if ip: + host = global_state.all_hosts.get(ip) + if host and host.compromised_by != 'None': + reward += 5.0 # Correctly quarantined a compromised host + else: + reward -= 2.0 # False positive — isolated a clean host + + # Successful restoration/cleanup + elif 'privilege' in delta_key and delta_val == 'None': + reward += 3.0 # Successfully cleaned a host + + elif 'status' in delta_key and delta_val == 'online': + reward += 1.0 # Restored connectivity + + # Deployed a decoy + elif 'decoy' in delta_key and delta_val != 'inactive': + reward += 1.0 # Proactive defense + + # ── Per-step network health (small, bounded) ────────── + # Only count the RATIO of healthy hosts to avoid scale issues + total_hosts = max(len(global_state.all_hosts), 1) + healthy = sum( 1 for h in global_state.all_hosts.values() - if getattr(h, 'system_integrity', 'clean') == 'compromised' + if h.compromised_by == 'None' and h.status != 'isolated' ) + health_ratio = healthy / total_hosts + reward += health_ratio * 1.0 # 0.0 to 1.0 per step - if 'red' in agent_id.lower(): - # Massive reward for each encrypted host - reward += red_impact_count * 50.0 + # Small penalty proportional to compromised hosts + compromised = sum( + 1 + for h in global_state.all_hosts.values() + if h.compromised_by != 'None' and h.status != 'isolated' + ) + if compromised > 0: + reward -= (compromised / total_hosts) * 2.0 # 0.0 to -2.0 - elif 'blue' in agent_id.lower(): - for h in global_state.all_hosts.values(): - is_isolated = h.status == 'isolated' - is_compromised = ( - getattr(h, 'system_integrity', 'clean') == 'compromised' - ) - - if is_compromised and not is_isolated: - reward -= 50.0 # Massive penalty for active ransomware - elif is_isolated and getattr(h, 'privilege', 'None') in [ - 'User', - 'Root', - ]: - reward += 10.0 # Successfully quarantined an infected host - elif not is_compromised and not is_isolated: - reward += 1.0 # General uptime reward + # ── Failed action penalty ───────────────────────────── + if effect and not effect.success: + reward -= 0.1 return reward