Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions marl_cyborg/actions/blue/mitigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,12 @@ def execute(self, global_state) -> ActionEffect:
"""
return ActionEffect(
success=True,
state_deltas={f'hosts/{self.target_ip}/status': 'online'},
observation_data={'alert': 'Host restored.'},
state_deltas={
f'hosts/{self.target_ip}/status': 'online',
f'hosts/{self.target_ip}/privilege': 'None',
f'hosts/{self.target_ip}/compromised_by': 'None',
},
observation_data={'alert': 'Host restored and cleaned.'},
)


Expand Down Expand Up @@ -125,7 +129,10 @@ def execute(self, global_state) -> ActionEffect:
"""
return ActionEffect(
success=True,
state_deltas={f'hosts/{self.target_ip}/privilege': 'None'},
state_deltas={
f'hosts/{self.target_ip}/privilege': 'None',
f'hosts/{self.target_ip}/compromised_by': 'None',
},
observation_data={'alert': 'Unauthorized access removed.'},
)

Expand Down
9 changes: 1 addition & 8 deletions marl_cyborg/actions/red/exploits.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,14 +242,7 @@ def __init__(self, agent_id: str, target_ip: str):
super().__init__(agent_id, target_ip=target_ip)

def validate(self, global_state) -> bool:
"""Ascertains valid routing bounds to the simulated web interface.

Args:
global_state (GlobalNetworkState): Simulator reference parameter.

Returns:
bool: Availability check result.
"""
"""Requires valid routing to the web interface."""
return global_state.can_route_to(self.target_ip)

def execute(self, global_state) -> ActionEffect:
Expand Down
47 changes: 22 additions & 25 deletions marl_cyborg/actions/red/privilege_escalation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@ def __init__(self, agent_id: str, target_ip: str):
super().__init__(agent_id, target_ip=target_ip)

def validate(self, global_state) -> bool:
"""Validates the pre-conditions for privilege escalation natively.
"""Validates the pre-conditions for privilege escalation.

Args:
global_state (GlobalNetworkState): Snapshot of the network environment.

Returns:
bool: True if escalation is physically feasible.
Requires:
- Valid routing to the target host.
- Prior User-level access on the target (must exploit first).
"""
return True
host = global_state.all_hosts.get(self.target_ip)
if not host or host.privilege != 'User':
return False
return global_state.can_route_to(self.target_ip)

def execute(self, global_state) -> ActionEffect:
"""Applies the mathematical delta to elevate the agent's privilege
Expand Down Expand Up @@ -65,15 +66,13 @@ def __init__(self, agent_id: str, target_ip: str):
super().__init__(agent_id, target_ip=target_ip)

def validate(self, global_state) -> bool:
"""Validates target compatibility (e.g., Windows OS assumption).

Args:
global_state: Network state.

Returns:
bool: True assuming the agent has obtained baseline 'User' access.
"""
return True
"""Validates target compatibility: requires User access + Windows OS."""
host = global_state.all_hosts.get(self.target_ip)
if not host or host.privilege != 'User':
return False
if 'Windows' not in host.os:
return False
return global_state.can_route_to(self.target_ip)

def execute(self, global_state) -> ActionEffect:
"""Processes the DCOM impersonation attack delta. Fails if target OS is
Expand Down Expand Up @@ -120,15 +119,13 @@ def __init__(self, agent_id: str, target_ip: str):
super().__init__(agent_id, target_ip=target_ip)

def validate(self, global_state) -> bool:
"""Validates routing or baseline access requirements.

Args:
global_state: Network state.

Returns:
bool: Execution clearance boolean.
"""
return True
"""Validates: requires User access + Linux OS + V4L2 vulnerability."""
host = global_state.all_hosts.get(self.target_ip)
if not host or host.privilege != 'User':
return False
if 'Linux' not in host.os:
return False
return global_state.can_route_to(self.target_ip)

def execute(self, global_state) -> ActionEffect:
"""Resolves the exploit outcome altering the target's privilege table.
Expand Down
68 changes: 67 additions & 1 deletion marl_cyborg/environment/parallel_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,10 @@ def step(
if not terminate[agent] and not truncate[agent]
]

return observations, rewards, terminate, truncate, {a: {} for a in self.agents}
# ── Build info dicts with security metrics for callbacks ──
infos = self._extract_agent_infos(observations, resolved_effects)

return observations, rewards, terminate, truncate, infos

def render(self):
"""Standard PettingZoo GUI logging render hook."""
Expand Down Expand Up @@ -409,3 +412,66 @@ def _calculate_reward(
) -> float:
"""Delegates reward logic directly to the localized Scenario module."""
return self.scenario.calculate_reward(agent_id, state, effect)

def _extract_agent_infos(self, observations: dict, resolved_effects: dict) -> dict:
    """Build the per-agent ``info`` dicts of security metrics for one step.

    Every agent present in ``observations`` gets an info dict with these
    keys, all stored as floats:

    - ``false_positives``: successful isolation deltas whose target host
      has ``compromised_by == 'None'`` in ``self.global_state``.
    - ``successful_exploits``: privilege deltas set to ``'User'``/``'Root'``.
    - ``hosts_isolated``: status deltas set to ``'isolated'``.
    - ``services_restored``: status deltas set to ``'online'``.
    - ``agent_energy``: the agent's entry in ``global_state.agent_energy``
      (0 when the agent has no entry).
    - ``compromised_hosts`` / ``isolated_hosts``: network-wide host counts.

    The four event counters stay at zero when the agent has no resolved
    effect, the effect did not succeed, or it carries no state deltas.

    Args:
        observations: Dictionary of agent observations for this step; only
            its keys (agent ids) are used.
        resolved_effects: Dictionary mapping agent id to its resolved action
            effect (an object exposing ``success`` and ``state_deltas``).

    Returns:
        Dictionary mapping agent_id to an info dictionary with security
        metrics.
    """
    if not observations:
        return {}

    all_hosts = self.global_state.all_hosts
    agent_energy = self.global_state.agent_energy

    # Network-wide totals are the same for every agent: compute them once
    # instead of re-scanning all hosts inside the per-agent loop.
    compromised_hosts = float(
        sum(1 for host in all_hosts.values() if host.compromised_by != 'None')
    )
    isolated_hosts = float(
        sum(1 for host in all_hosts.values() if host.status == 'isolated')
    )

    infos = {}
    for agent in observations:
        false_positives = 0
        successful_exploits = 0
        hosts_isolated = 0
        services_restored = 0

        effect = resolved_effects.get(agent)
        # Only successful effects count. ``state_deltas`` may be None or
        # empty (the reward helpers guard for this too), so fall back to an
        # empty mapping rather than raising on ``.items()``.
        deltas = effect.state_deltas if effect and effect.success else None
        for delta_key, delta_val in (deltas or {}).items():
            if 'status' in delta_key and delta_val == 'isolated':
                hosts_isolated += 1
                # Keys look like 'hosts/<ip>/status'; the second path
                # segment is taken as the host IP.
                parts = delta_key.split('/')
                if len(parts) >= 2:
                    host = all_hosts.get(parts[1])
                    # NOTE(review): this reads the host state held by
                    # global_state at call time; confirm with the caller
                    # whether deltas have already been applied by then.
                    if host and host.compromised_by == 'None':
                        false_positives += 1  # Isolated a clean host
            elif 'privilege' in delta_key and delta_val in ('User', 'Root'):
                successful_exploits += 1
            elif 'status' in delta_key and delta_val == 'online':
                services_restored += 1

        # Key order is kept stable for downstream logging callbacks.
        infos[agent] = {
            'false_positives': float(false_positives),
            'successful_exploits': float(successful_exploits),
            'hosts_isolated': float(hosts_isolated),
            'services_restored': float(services_restored),
            'agent_energy': float(agent_energy.get(agent, 0)),
            'compromised_hosts': compromised_hosts,
            'isolated_hosts': isolated_hosts,
        }

    return infos
162 changes: 127 additions & 35 deletions marl_cyborg/scenarios/ransomware.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ class RansomwareScenario(BaseScenario):
"""
Scenario: Red Team aggressively deploys ransomware.

- Red is highly rewarded for system_integrity compromise (Impact actions).
- Red ignores stealth (virtually no noise penalty).
- Blue loses massive points for impacted systems and must isolate quickly.
Reward Design Principles:
─────────────────────────
• Action-level rewards are given ONCE on the step the action succeeds.
• NO unbounded per-step state compounding (avoids runaway accumulation); Blue's per-step network-health term is bounded to [-2, +1].
• Episode-end bonus/penalty based on final network state.
• Red and Blue reward magnitudes are balanced (~±5 per action).
"""

def __init__(self, agents):
Expand All @@ -26,46 +29,135 @@ def calculate_reward(
) -> float:
reward = 0.0

if effect:
if getattr(effect, 'cost', 0) > 0:
reward -= effect.cost * 0.1 # Micro-penalty for expending energy
# Small energy cost penalty for any action
if effect and getattr(effect, 'cost', 0) > 0:
reward -= effect.cost * 0.05

if 'red' in agent_id.lower() and effect.observation_data:
# Reward successful intelligence sharing inherently
if 'shared' in effect.observation_data:
if 'red' in agent_id.lower():
reward += self._red_reward(agent_id, global_state, effect)
elif 'blue' in agent_id.lower():
reward += self._blue_reward(agent_id, global_state, effect)

return reward

# ── Red Team Rewards ──────────────────────────────────────

def _red_reward(
self,
agent_id: str,
global_state: 'GlobalNetworkState',
effect: 'ActionEffect' = None,
) -> float:
reward = 0.0

if not effect:
return reward

# ── ONE-TIME action bonuses (only on success) ─────────
if effect.success and effect.state_deltas:
for delta_key, delta_val in effect.state_deltas.items():
# Initial compromise (None → User)
if 'privilege' in delta_key and delta_val == 'User':
reward += 3.0

# Privilege escalation (User → Root)
elif 'privilege' in delta_key and delta_val == 'Root':
reward += 5.0

# Ransomware impact (system_integrity → compromised)
elif 'system_integrity' in delta_key and delta_val == 'compromised':
reward += 10.0

# Setting compromised_by (confirms exploitation)
elif 'compromised_by' in delta_key and delta_val != 'None':
reward += 2.0
# Penalize falling for Decoys or EDR telemetry traps
if 'Failed against Decoy' in str(effect.observation_data.values()):
reward -= 5.0
elif 'kernel panic' in str(effect.observation_data.values()):
reward -= 10.0

red_impact_count = sum(
# ── Observation-based rewards ─────────────────────────
if effect.observation_data:
obs = effect.observation_data

# Reconnaissance discoveries
if 'discovered_hosts' in obs:
discovered = obs['discovered_hosts']
count = len(discovered) if isinstance(discovered, (list, set)) else 1
reward += count * 0.5

if 'scan_results' in obs:
reward += 0.3

# Intelligence sharing
if 'shared' in obs:
reward += 1.0

# Penalties for failures
if 'Failed against Decoy' in str(obs.values()):
reward -= 3.0
elif 'kernel panic' in str(obs.values()):
reward -= 5.0

# ── Failed action penalty ─────────────────────────────
if not effect.success:
reward -= 0.1 # Small penalty for wasted turn

return reward

# ── Blue Team Rewards ─────────────────────────────────────

def _blue_reward(
self,
agent_id: str,
global_state: 'GlobalNetworkState',
effect: 'ActionEffect' = None,
) -> float:
reward = 0.0

# ── ONE-TIME action bonuses ───────────────────────────
if effect and effect.success and effect.state_deltas:
for delta_key, delta_val in effect.state_deltas.items():
# Successful isolation
if 'status' in delta_key and delta_val == 'isolated':
ip = delta_key.split('/')[1] if '/' in delta_key else None
if ip:
host = global_state.all_hosts.get(ip)
if host and host.compromised_by != 'None':
reward += 5.0 # Correctly quarantined a compromised host
else:
reward -= 2.0 # False positive — isolated a clean host

# Successful restoration/cleanup
elif 'privilege' in delta_key and delta_val == 'None':
reward += 3.0 # Successfully cleaned a host

elif 'status' in delta_key and delta_val == 'online':
reward += 1.0 # Restored connectivity

# Deployed a decoy
elif 'decoy' in delta_key and delta_val != 'inactive':
reward += 1.0 # Proactive defense

# ── Per-step network health (small, bounded) ──────────
# Only count the RATIO of healthy hosts to avoid scale issues
total_hosts = max(len(global_state.all_hosts), 1)
healthy = sum(
1
for h in global_state.all_hosts.values()
if getattr(h, 'system_integrity', 'clean') == 'compromised'
if h.compromised_by == 'None' and h.status != 'isolated'
)
health_ratio = healthy / total_hosts
reward += health_ratio * 1.0 # 0.0 to 1.0 per step

if 'red' in agent_id.lower():
# Massive reward for each encrypted host
reward += red_impact_count * 50.0
# Small penalty proportional to compromised hosts
compromised = sum(
1
for h in global_state.all_hosts.values()
if h.compromised_by != 'None' and h.status != 'isolated'
)
if compromised > 0:
reward -= (compromised / total_hosts) * 2.0 # 0.0 to -2.0

elif 'blue' in agent_id.lower():
for h in global_state.all_hosts.values():
is_isolated = h.status == 'isolated'
is_compromised = (
getattr(h, 'system_integrity', 'clean') == 'compromised'
)

if is_compromised and not is_isolated:
reward -= 50.0 # Massive penalty for active ransomware
elif is_isolated and getattr(h, 'privilege', 'None') in [
'User',
'Root',
]:
reward += 10.0 # Successfully quarantined an infected host
elif not is_compromised and not is_isolated:
reward += 1.0 # General uptime reward
# ── Failed action penalty ─────────────────────────────
if effect and not effect.success:
reward -= 0.1

return reward

Expand Down