-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_utils.py
More file actions
650 lines (474 loc) · 25 KB
/
train_utils.py
File metadata and controls
650 lines (474 loc) · 25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
from typing import Callable, Any, Sequence, Tuple
from enum import Enum
import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np
import optax
import chex
# Signature of the optimizer callables used in this module (e.g. the object
# returned by `optax_optimizer`): takes (params, grads), returns new params.
Optimizer = Callable[[chex.ArrayTree, chex.ArrayTree], chex.ArrayTree]
class SimilarityMetric(str, Enum):
    """Names of the available similarity metrics.

    Members mix in ``str``, so each one compares equal to its string value
    (``SimilarityMetric.POLICY == "policy"``) and can be recovered from that
    value with ``SimilarityMetric("policy")``.
    """

    POLICY = "policy"
    VALUE = "value"
    POLICY_VALUE = "policy_value"
    LEGAL_ACTIONS = "legal_actions"
    LEGAL_POLICY_VALUE = "legal_policy_value"
    ACTION_HISTORY_POLICY = "action_history_policy"
class DynamicsType(str, Enum):
    """Names of the supported dynamics types.

    String-valued: ``DynamicsType.ISET == "iset"`` and
    ``DynamicsType("public_state") is DynamicsType.PUBLIC_STATE``.
    """

    ISET = "iset"
    PUBLIC_STATE = "public_state"
class EntropySchedule:
    """Schedule of learner steps at which the regularisation network is updated.

    Adapted from OpenSpiel's R-NaD. The schedule is an increasing list of
    learner steps built from iteration ``sizes``, each repeated ``repeats``
    times; past the explicit schedule, the last size repeats forever.

    Example:
        EntropySchedule(sizes=[3, 5, 10], repeats=[2, 4, 1])
        => schedule == [0, 3, 6, 11, 16, 21, 26, 36]
           |  3 x2  |     5 x4      | 10 x1 (then every 10 steps)
    """

    def __init__(self, *, sizes: Sequence[int], repeats: Sequence[int]):
        """Constructs a schedule of entropy iterations.

        Args:
            sizes: the list of iteration sizes; every size must be strictly
                positive.
            repeats: the list, parallel to ``sizes``, with the number of times
                each size is repeated. All values must be strictly positive and
                the last one must be 1, since the last size repeats forever.

        Raises:
            ValueError: if the arguments violate any constraint above. The
                specific reason is attached as the exception's ``__cause__``.
        """
        try:
            if len(repeats) != len(sizes):
                raise ValueError("`repeats` must be parallel to `sizes`.")
            if not sizes:
                raise ValueError("`sizes` and `repeats` must not be empty.")
            if any(repeat <= 0 for repeat in repeats):
                raise ValueError("All repeat values must be strictly positive")
            # A non-positive size would make the schedule non-increasing and
            # cause a division by zero in `__call__`.
            if any(size <= 0 for size in sizes):
                raise ValueError("All size values must be strictly positive")
            if repeats[-1] != 1:
                raise ValueError("The last value in `repeats` must be equal to 1, "
                                 "since the last iteration size is repeated forever.")
        except ValueError as e:
            raise ValueError(
                f"Entropy iteration schedule: repeats ({repeats}) and sizes"
                f" ({sizes})."
            ) from e
        schedule = [0]
        for size, repeat in zip(sizes, repeats):
            # Each size contributes `repeat` boundaries spaced `size` apart.
            schedule.extend([schedule[-1] + (i + 1) * size for i in range(repeat)])
        self.schedule = np.array(schedule, dtype=np.int32)

    def __call__(self, learner_step: int) -> Tuple[float, bool]:
        """Entropy scheduling parameters for a given `learner_step`.

        Args:
            learner_step: The current learning step.

        Returns:
            alpha: The mixing weight (from [0, 1]) of the previous policy with
                the one before for computing the intrinsic reward.
            update_target_net: A boolean indicator for updating the target
                network with the current network.
        """
        # Both cases are computed and blended with a boolean selector
        # (A * s + B * (1 - s)) instead of a conditional:
        #   1. learner_step is past the explicit schedule
        #      (schedule[-1] <= learner_step): the last size repeats forever.
        #   2. learner_step is inside the explicit schedule.
        last_size = self.schedule[-1] - self.schedule[-2]
        last_start = self.schedule[-1] + (
            learner_step - self.schedule[-1]) // last_size * last_size
        start = jnp.amax(self.schedule * (self.schedule <= learner_step))
        finish = jnp.amin(
            self.schedule * (learner_step < self.schedule),
            initial=self.schedule[-1],
            where=(learner_step < self.schedule))
        size = finish - start
        beyond = (self.schedule[-1] <= learner_step)  # Past the schedule?
        iteration_start = (last_start * beyond + start * (1 - beyond))
        iteration_size = (last_size * beyond + size * (1 - beyond))
        # Update on the last step of each iteration, never at step 0.
        update_target_net = jnp.logical_and(
            learner_step > 0,
            jnp.sum(learner_step == iteration_start + iteration_size - 1),
        )
        # alpha ramps linearly from 0 to 1 over the first half of the iteration.
        alpha = jnp.minimum(
            (2.0 * (learner_step - iteration_start)) / iteration_size, 1.0)
        return alpha, update_target_net  # pytype: disable=bad-return-type  # jax-types
def masked_l2_loss(y_predicted, y_target, mask):
    """Sum of ``mask * (y_target - y_predicted) ** 2`` over every element.

    No gradient flows into ``y_target``. The three inputs must have equal rank
    (checked with chex); their shapes only need to broadcast.
    """
    chex.assert_equal_rank((y_predicted, y_target, mask))
    residual = lax.stop_gradient(y_target) - y_predicted
    return jnp.sum((residual ** 2) * mask)
def masked_l2_loss_with_normalization(y_predicted, y_target, mask, norm):
    """`masked_l2_loss` divided by ``norm``.

    A ``norm`` equal to 0 is replaced by 1 so the division never hits zero.
    Input shapes (except ``norm``) must be compatible as for `masked_l2_loss`.
    """
    safe_norm = norm + (norm == 0)
    return masked_l2_loss(y_predicted, y_target, mask) / safe_norm
def optax_optimizer(
    params: chex.ArrayTree,
    init_and_update: optax.GradientTransformation) -> Optimizer:
    """Wraps an optax transformation as a stateful ``(params, grads) -> params``.

    The returned object holds the optax state (initialised from ``params``).
    Each call replaces that state in place and returns the updated params.
    """
    init, update = init_and_update

    @chex.dataclass
    class OptaxOptimizer:
        """A jax-friendly representation of an optimizer state with the update."""
        state: chex.Array

        def __call__(self, params: chex.ArrayTree, grads: chex.ArrayTree) -> chex.ArrayTree:
            updates, new_state = update(grads, self.state, params)  # pytype: disable=annotation-type-mismatch  # numpy-scalars
            self.state = new_state
            return optax.apply_updates(params, updates)

    return OptaxOptimizer(state=init(params))
def init_params_optimizer(
    network,
    rng_key: chex.PRNGKey,
    init_input,
    optimizer_init: optax.GradientTransformation = optax.chain(optax.adamw(1e-3), optax.clip(100)),
):
    """Initialises ``network`` parameters and an optimizer bound to them.

    Args:
        network: object exposing ``init(rng_key, init_input)``.
        rng_key: PRNG key forwarded to ``network.init``.
        init_input: example input forwarded to ``network.init``.
        optimizer_init: optax transformation wrapped by `optax_optimizer`.
            NOTE(review): the default clips AdamW's *output update* (chain order),
            not the raw gradient — confirm this ordering is intended.

    Returns:
        ``(params, optimizer)``.
    """
    params = network.init(rng_key, init_input)
    return params, optax_optimizer(params, optimizer_init)
def init_network_with_optimizer(
    network_class,
    rng_key: chex.PRNGKey,
    init_input,
    optimizer_init: optax.GradientTransformation = optax.chain(optax.adamw(1e-3), optax.clip(100)),
    network_args: tuple = (),
):
    """Builds ``network_class(*network_args)`` and initialises it with an optimizer.

    Returns:
        ``(network, params, optimizer)``; see `init_params_optimizer`.
    """
    network = network_class(*network_args)
    return (network, *init_params_optimizer(network, rng_key, init_input, optimizer_init))
def _policy_ratio(pi: chex.Array, mu: chex.Array, actions_oh: chex.Array, valid: chex.Array) -> chex.Array:
    """Ratio pi(a) / mu(a) of the one-hot selected action, keeping the last axis.

    Where ``valid`` is 0, both probabilities are replaced by 1, so the ratio is 1.
    """
    def _taken_prob(policy):
        selected = jnp.sum(policy * actions_oh, axis=-1, keepdims=True)
        return selected * valid + (1 - valid)

    return _taken_prob(pi) / _taken_prob(mu)
def tree_where(pred: chex.Array, x: chex.ArrayTree, y: chex.ArrayTree) -> chex.ArrayTree:
    """Leaf-wise ``jnp.where(pred, x_leaf, y_leaf)`` over two matching pytrees."""
    return jax.tree.map(lambda x_leaf, y_leaf: jnp.where(pred, x_leaf, y_leaf), x, y)
def apply_force_with_threshold(decision_outputs: chex.Array, force: chex.Array,
                               threshold: float,
                               threshold_center: chex.Array) -> chex.Array:
    """Applies ``force`` to ``decision_outputs`` only while within a threshold band.

    A negative force is kept only where ``decision_outputs - threshold_center``
    exceeds ``-threshold``. A positive force is kept only where it is below
    ``threshold``. The clipped force is stop-gradient'ed and multiplied by
    ``decision_outputs``, so gradients flow through ``decision_outputs`` only.
    """
    chex.assert_equal_shape((decision_outputs, force, threshold_center))
    offset = decision_outputs - threshold_center
    downward = (offset > -threshold) * jnp.minimum(force, 0.0)
    upward = (offset < threshold) * jnp.maximum(force, 0.0)
    return decision_outputs * lax.stop_gradient(downward + upward)
def neurd_loss(
    logits: chex.Array,
    policy: chex.Array,
    q_values: chex.Array,
    legal: chex.Array,
    importance_sampling: chex.Array,
    clip: float = 10_000,
    threshold: float = 2.0
):
    """NeuRD policy-loss term summed over legal actions (per row, last axis kept).

    The advantage ``q - E_policy[q]`` is scaled by ``importance_sampling``,
    clipped to ``[-clip, clip]`` and treated as a constant. Logits are centred
    on their mean over legal actions, then pushed by the advantage through
    `apply_force_with_threshold` (band ``threshold`` around 0).

    Returns the positive sum; NOTE(review): the caller presumably negates /
    normalises it — confirm.
    """
    advantage = q_values - jnp.sum(policy * q_values, axis=-1, keepdims=True)
    advantage = advantage * importance_sampling
    advantage = lax.stop_gradient(jnp.clip(advantage, -clip, clip))
    num_legal = jnp.sum(legal, axis=-1, keepdims=True)
    # Rows with no legal action (e.g. padding) would give 0/0 = NaN here, and
    # the NaN survives the `legal *` mask below (0 * NaN = NaN), poisoning the
    # loss and its gradients. Dividing by 1 instead yields a zero loss there.
    mean_logit = jnp.sum(logits * legal, axis=-1, keepdims=True) / (num_legal + (num_legal == 0))
    logits_shifted = logits - mean_logit
    threshold_center = jnp.zeros_like(logits_shifted)
    forces = apply_force_with_threshold(logits_shifted, advantage, threshold, threshold_center)
    return jnp.sum(legal * forces, axis=-1, keepdims=True)
def transform_trajectory_to_last_dimension(x: chex.Array) -> chex.Array:
    """Folds the leading axis of ``x`` into its last axis.

    For ``x`` of shape ``(T, *batch, D)`` the result has shape
    ``(*batch, T * D)`` with ``result[..., t * D + d] == x[t, ..., d]``: the
    per-step vectors are concatenated in step order. Requires ``x.ndim >= 2``.
    """
    leading_moved = jnp.moveaxis(x, 0, -2)  # (*batch, T, D)
    batch_shape = x.shape[1:-1]
    return leading_moved.reshape((*batch_shape, -1))
def normalize_direction_with_mask(x:chex.Array, mask:chex.Array) -> chex.Array:
    """Masks ``x`` and scales it to unit L2 norm along the last axis.

    Rows whose squared norm is below 1e-15 have 1 added to it before the
    square root, so they are returned essentially unscaled instead of being
    divided by ~0.
    """
    masked = mask * x
    squared_norm = jnp.sum(masked ** 2, axis=-1, keepdims=True)
    safe_norm = (squared_norm + (squared_norm < 1e-15)) ** 0.5
    return masked / safe_norm
def compute_soft_assignments(cluster_distance: chex.Array, temperature: float, cluster_closeness_assignment: float, repulsive_force: float):
    '''Soft k-means (fuzzy c-means style) assignment weights over the last axis.

    Weights are ``softmax(-cluster_distance * temperature)``. This deviates from
    the original method: every centre closer than ``cluster_closeness_assignment``
    that is not the nearest one gets a negative weight
    ``-weight * repulsive_force``. When several centres are that close to a
    point, only the nearest is attracted and the others are pushed away.
    '''
    weights = jax.nn.softmax(-cluster_distance * temperature, axis=-1)
    nearest = jnp.min(cluster_distance, -1, keepdims=True)
    close_but_not_nearest = jnp.logical_and(
        cluster_distance < cluster_closeness_assignment,
        cluster_distance > nearest + 1e-10,
    )
    return jnp.where(close_but_not_nearest, -weights * repulsive_force, weights)
def compute_soft_hard_assignment(cluster_distance: chex.Array, temperature: float, cluster_closeness_assignment: float, repulsive_force: float, hard_k_means_closeness: float):
    '''Mixed soft/hard k-means assignment weights over the last axis (no gradient).

    ``cluster_distance`` is compared against *squared* thresholds, so it is
    expected to hold squared distances.

    * Soft branch: ``softmax(-cluster_distance * temperature)``. Centres within
      ``cluster_closeness_assignment ** 2`` that are not the nearest get
      ``-weight * repulsive_force`` (pushed away).
    * Hard branch: used when the nearest centre is within
      ``hard_k_means_closeness ** 2``. A one-hot (float32) on the nearest
      centre(s), where ties within 1e-10 are all set to 1.
    '''
    nearest = jnp.min(cluster_distance, -1, keepdims=True)
    soft = jax.nn.softmax(-cluster_distance * temperature, axis=-1)
    repel = jnp.logical_and(
        cluster_distance < (cluster_closeness_assignment ** 2),
        cluster_distance > (nearest + 1e-10),
    )
    soft = jnp.where(repel, -soft * repulsive_force, soft)
    hard = (cluster_distance <= (nearest + 1e-10)).astype(jnp.float32)
    use_hard = nearest < (hard_k_means_closeness ** 2)
    return lax.stop_gradient(jnp.where(use_hard, hard, soft))
def compute_energy_repulsion(pred: chex.Array):
    """Mean ``exp(-||p_i - p_j||^2)`` over all pairs of rows (axis -2) of ``pred``.

    Pairs whose squared distance is below 1e-8 contribute 0; this includes
    every row paired with itself.
    """
    pairwise = pred[..., :, None, :] - pred[..., None, :, :]
    squared_distance = jnp.sum(pairwise ** 2, axis=-1)
    energy = jnp.exp(-squared_distance)
    return jnp.mean(jnp.where(squared_distance < 1e-8, 0, energy))
def compute_energy_repulsion_inverse(pred: chex.Array):
    """Mean ``1 / (||p_i - p_j||^2 + 1e-9)`` over all pairs of rows (axis -2) of ``pred``.

    Pairs whose squared distance is below 1e-8 contribute 0; this includes
    every row paired with itself.
    """
    pairwise = pred[..., :, None, :] - pred[..., None, :, :]
    squared_distance = jnp.sum(pairwise ** 2, axis=-1)
    inverse = 1 / (squared_distance + 1e-9)
    return jnp.mean(jnp.where(squared_distance < 1e-8, 0, inverse))
def compute_separation_loss(pred: chex.Array, cluster_closeness: float = 1.0):
    """Pairwise hinge ``max(0, cluster_closeness - ||p_i - p_j||^2)``, shape ``(..., K, K)``.

    ``p_i`` are the rows of ``pred`` along axis -2. The diagonal (i == j) has
    distance 0, so it contributes the constant ``max(0, cluster_closeness)``,
    which has zero gradient.
    """
    pairwise = pred[..., :, None, :] - pred[..., None, :, :]
    squared_distance = jnp.sum(pairwise ** 2, axis=-1)
    return jnp.maximum(0, cluster_closeness - squared_distance)
def pullback_loss(pred: chex.Array):
    """Squared L2 norm of ``pred`` along its last axis (penalises distance from the origin)."""
    squared = pred ** 2
    return jnp.sum(squared, axis=-1)
def compute_soft_kmeans_transformations(real:chex.Array, pred: chex.Array, valid: chex.Array, temperature: float, cluster_closeness_assignment: float, repulsive_force: float):
    """Soft k-means loss of points ``real`` against centres ``pred``.

    ``real`` lacks axis -2 of ``pred`` (the centre axis). Differences are masked
    by ``valid`` and taken against a stop-gradient'ed ``real``. Soft assignments
    come from `compute_soft_assignments` on the (non-squared) Euclidean
    distance. They are NOT stop-gradient'ed here, so gradients also flow
    through the assignments.

    Returns:
        ``(mean over all entries of the valid-masked weighted loss, assignments)``.
    """
    chex.assert_shape((real,), (*pred.shape[:-2], pred.shape[-1]))
    offsets = (lax.stop_gradient(jnp.expand_dims(real, -2)) - pred) * valid[..., None, None]
    squared = offsets ** 2
    squared_distance = jnp.sum(squared, axis=-1)
    # Near-zero distances are bumped by 1 before the sqrt (finite gradient).
    distance = (squared_distance + (squared_distance < 1e-15)) ** 0.5
    assignments = compute_soft_assignments(distance, temperature, cluster_closeness_assignment, repulsive_force)
    per_point = jnp.sum(jnp.mean(squared, axis=-1) * assignments, axis=-1) * valid
    return jnp.mean(per_point), assignments
def _compute_soft_kmeans_loss_with_cluster_assignments(real:chex.Array, pred: chex.Array, valid: chex.Array, temperature: float, cluster_closeness_assignment: float, repulsive_force: float, hard_k_means_closeness: float):
    """Soft/hard k-means loss of points ``real`` against centres ``pred``, plus regularisers.

    ``real`` lacks axis -2 of ``pred`` (the centre axis, size K). Assignments
    come from `compute_soft_hard_assignment` on squared distances and carry no
    gradient. The loss is the sum of three terms, each averaged over valid
    points only:

    * weighted k-means term, scaled by ``sqrt(K)``;
    * pairwise separation hinge between centres (weight 0.2);
    * pullback of centres towards the origin (weight 1e-4).

    Returns:
        ``(loss, assignments)``.
    """
    chex.assert_shape((real,), (*pred.shape[:-2], pred.shape[-1]))
    cluster_difference = lax.stop_gradient(jnp.expand_dims(real, -2)) - pred
    cluster_difference = cluster_difference * valid[..., None, None]
    # Squared distances; compute_soft_hard_assignment squares its thresholds.
    cluster_distance = jnp.sum(cluster_difference ** 2, axis=-1)
    cluster_soft_assignement = compute_soft_hard_assignment(cluster_distance, temperature, cluster_closeness_assignment, repulsive_force, hard_k_means_closeness)

    # Number of valid points; the `+ (normalization == 0)` guards an all-invalid batch.
    normalization = jnp.sum(valid)

    cluster_separation_loss = compute_separation_loss(pred, cluster_closeness_assignment) * 0.2
    cluster_pullback_loss = pullback_loss(pred) * 0.0001
    cluster_separation_loss = cluster_separation_loss * valid[..., None, None]
    cluster_pullback_loss = cluster_pullback_loss * valid[..., None]
    cluster_separation_loss = cluster_separation_loss / (cluster_separation_loss.shape[-2] * cluster_separation_loss.shape[-1] * normalization + (normalization == 0))
    cluster_pullback_loss = cluster_pullback_loss / (cluster_pullback_loss.shape[-1] * normalization + (normalization == 0))
    cluster_separation_loss = jnp.sum(cluster_separation_loss)
    cluster_pullback_loss = jnp.sum(cluster_pullback_loss)

    cluster_loss = jnp.mean(cluster_difference ** 2, axis=-1)
    cluster_loss = jnp.sum(cluster_loss * cluster_soft_assignement, axis=-1) * valid
    cluster_loss = cluster_loss * jnp.sqrt(pred.shape[-2])
    # Average over valid points, like the two regularisers above. A plain mean
    # would also count invalid (zeroed) entries and shrink this term whenever
    # the batch contains padding. Identical to the mean when all are valid.
    cluster_loss = jnp.sum(cluster_loss) / (normalization + (normalization == 0))
    return cluster_loss + cluster_separation_loss + cluster_pullback_loss, cluster_soft_assignement
def _compute_soft_kmeans_loss_with_single(real: chex.Array, pred: chex.Array, probs: chex.Array, valid: chex.Array, temperature: float, cluster_closeness_assignment: float, repulsive_force: float, hard_k_means_closeness: float):
    """k-means loss plus a cross-entropy term teaching ``probs`` the assignment.

    Targets are 0.95 on the maximal assignment entries and 0.05 elsewhere.
    ``probs`` is passed as the *logits* argument of ``optax.softmax_cross_entropy``.
    The cross-entropy is masked by ``valid`` and averaged over all entries.
    NOTE(review): for more than 2 centres the targets do not sum to 1 — confirm intended.
    """
    kmeans_loss, assignments = _compute_soft_kmeans_loss_with_cluster_assignments(real, pred, valid, temperature, cluster_closeness_assignment, repulsive_force, hard_k_means_closeness)
    is_best = assignments >= jnp.max(assignments, -1, keepdims=True)
    smoothed_targets = jnp.where(is_best, 0.95, 0.05)
    classification_loss = optax.softmax_cross_entropy(probs, smoothed_targets) * valid
    return kmeans_loss + jnp.mean(classification_loss)
def v_trace(
    v: chex.Array,
    valid: chex.Array,
    sampling_policy: chex.Array,
    network_policy: chex.Array,
    regularization_term: chex.Array,
    action_oh: chex.Array,
    reward: chex.Array,  # Still not regularized
    lambda_: float = 1.0,  # Lambda parameter for V-trace
    c: float = 1.0,  # Importance sampling clipping (trace coefficient)
    rho: float = np.inf,  # Importance sampling clipping (TD error)
    eta: float = 0.2,  # Regularization factor
    gamma: float = 1.0  # Discount factor
):
    """Regularised two-player V-trace value targets and per-action Q estimates.

    Inputs are time-major: the scan runs over axis 0 in reverse. Policy-derived
    quantities carry a player axis of size 2. Player 0 receives ``reward``,
    player 1 ``-reward``, and `jnp.flip(..., -2)` gives the opponent's ratio.

    Returns:
        ``(v_target, q_value)``, both zero on steps where ``valid`` is false.
    """
    # pi(a)/mu(a) of the taken action (1 on invalid steps).
    importance_sampling = _policy_ratio(network_policy, sampling_policy, action_oh, valid)
    # 1/mu(a) of the taken action: weights by how often the action is sampled.
    inverted_sampling = _policy_ratio(jnp.ones_like(sampling_policy), sampling_policy, action_oh, valid)
    # eta * E_pi[regularization_term] per player.
    regularization_entropy = eta * jnp.sum(network_policy * regularization_term, axis=-1)
    weighted_regularization_term = -eta * regularization_term
    # Zero-sum reward shaped by the difference of the players' regularisation.
    both_player_entropy = regularization_entropy[..., 1] - regularization_entropy[..., 0]
    entropy_reward = reward + both_player_entropy
    entropy_reward = jnp.expand_dims(jnp.stack((entropy_reward, -entropy_reward), axis=-1), -1)
    # Per-player reward for the Q estimate plus the *opponent's* regularisation.
    q_reward = jnp.stack((reward, -reward), axis=-1) + regularization_entropy[..., (1, 0)]
    q_reward = jnp.expand_dims(q_reward, -1)

    @chex.dataclass(frozen=True)
    class VTraceCarry:
        next_value: chex.Array  # Network value in the next timestep
        delta_v: chex.Array  # Propagated delta V in V-trace from the next timestep

    init_carry = VTraceCarry(
        next_value=jnp.zeros_like(v[-1]),
        delta_v=jnp.zeros_like(v[-1])
    )

    def _v_trace(carry: VTraceCarry, x) -> tuple[VTraceCarry, Any]:
        (is_t, v_t, q_reward_t, entropy_reward_t, weighted_reg_t,
         valid_t, inverted_sampling_t, action_oh_t) = x
        delta_v = jnp.minimum(rho, is_t) * (entropy_reward_t + gamma * carry.next_value - v_t)
        carry_delta_v = delta_v + lambda_ * jnp.minimum(c, is_t) * gamma * carry.delta_v
        v_target = v_t + carry_delta_v
        # TODO: Shall we use opponent entropy reg term or entropy of played action?
        # The taken action's Q is corrected with the opponent's importance ratio.
        opponent_sampling = jnp.flip(is_t, -2)
        # carry.next_value + carry.delta_v is the next step's V-trace target.
        q_value = v_t + weighted_reg_t + action_oh_t * opponent_sampling * inverted_sampling_t * (
            q_reward_t + gamma * (carry.next_value + carry.delta_v) - v_t)
        next_carry = VTraceCarry(
            next_value=v_t,
            delta_v=carry_delta_v
        )
        # Invalid steps reset the trace and emit zeros.
        return tree_where(
            valid_t,
            (next_carry, (v_target, q_value)),
            (init_carry, (jnp.zeros_like(v_target), jnp.zeros_like(q_value))),
        )

    _, (v_target, q_value) = lax.scan(
        f=_v_trace,
        init=init_carry,
        xs=(importance_sampling, v, q_reward, entropy_reward, weighted_regularization_term, valid, inverted_sampling, action_oh),
        reverse=True
    )
    return v_target, q_value
def retrace(
    q: chex.Array,
    valid: chex.Array,
    sampling_policy: chex.Array,
    network_policy: chex.Array,
    action_oh: chex.Array,
    reward: chex.Array,  # Still not regularized
    lambda_: float = 1.0,  # Lambda parameter for the trace
    c: float = 1.0,  # Importance sampling clipping of the trace coefficient
    rho: float = np.inf,  # Importance sampling clipping of the TD error
    gamma: float = 1.0  # Discount factor
):
    """Q-value targets from a reverse-time, V-trace-style recursion on Q.

    For each step t (processed backwards with ``lax.scan(reverse=True)``)::

        delta_t  = min(rho, r_t) * (reward_t + gamma * next_v_{t+1} - q_t)
        D_t      = delta_t + lambda_ * min(c, r_t) * gamma * D_{t+1}
        target_t = q_t + D_t

    Here ``r_t = pi_t(a_t) / mu_t(a_t)`` for the taken action (1 on invalid
    steps) and ``next_v_{t+1} = sum_a pi_{t+1}(a) * q'_{t+1}(a)``, where ``q'``
    equals ``q`` except that the taken action's entry is replaced by its target.
    Invalid steps reset the trace and emit zeros.

    NOTE(review): inputs are assumed time-major with a trailing action axis on
    ``q`` / policies. ``valid`` and ``reward`` are assumed to broadcast against
    ``q_t`` (e.g. a trailing singleton axis) — confirm against callers.
    """
    # Not clipped here: the scan clips the ratio separately by `rho` (TD error)
    # and by `c` (trace). Clipping by both up front would silently make the
    # effective rho equal to min(rho, c).
    importance_sampling = _policy_ratio(network_policy, sampling_policy, action_oh, valid)

    @chex.dataclass(frozen=True)
    class ReTraceCarry:
        next_v: chex.Array  # E_{pi_{t+1}}[q'_{t+1}], trailing singleton action axis
        delta_q: chex.Array  # Propagated correction D_{t+1} from the next timestep

    init_carry = ReTraceCarry(
        # Same expression as the per-step `next_v`, so the carry shape is
        # identical across scan iterations (lax.scan requires this).
        next_v=jnp.zeros_like(jnp.sum(network_policy[-1] * q[-1], axis=-1, keepdims=True)),
        delta_q=jnp.zeros_like(q[-1]),
    )

    def _retrace(carry: ReTraceCarry, x) -> tuple[ReTraceCarry, Any]:
        importance_sampling_t, q_t, reward_t, valid_t, action_oh_t, policy_t = x
        delta_q = jnp.minimum(rho, importance_sampling_t) * (reward_t + gamma * carry.next_v - q_t)
        carry_delta_q = delta_q + lambda_ * jnp.minimum(c, importance_sampling_t) * gamma * carry.delta_q
        q_target = q_t + carry_delta_q
        # Only the taken action's Q is replaced by its target.
        next_q = action_oh_t * q_target + (1 - action_oh_t) * q_t
        # Expectation under THIS step's policy slice, not the full [T, ...]
        # policy tensor. keepdims so it broadcasts against q over actions.
        next_v = jnp.sum(policy_t * next_q, axis=-1, keepdims=True)
        next_carry = ReTraceCarry(
            next_v=next_v,
            delta_q=carry_delta_q
        )
        # Invalid (e.g. padded) steps reset the trace, as in the other traces.
        return tree_where(
            valid_t,
            (next_carry, q_target),
            (init_carry, jnp.zeros_like(q_target)),
        )

    _, q_target = lax.scan(
        f=_retrace,
        init=init_carry,
        xs=(importance_sampling, q, reward, valid, action_oh, network_policy),
        reverse=True
    )
    return q_target
def state_v_trace(
    v: chex.Array,
    sampling_policy: chex.Array,
    transformed_policy: chex.Array,
    actions_oh: chex.Array,
    valid: chex.Array,
    reward: chex.Array, # Still not regularized
    lambda_: float = 1.0, # Lambda parameter for V-trace
    c: float = 1.0, # Importance sampling clipping (trace coefficient)
    rho: float = 1.0, # Importance sampling clipping (TD error)
    gamma: float = 1.0 # Discount factor
) -> chex.Array:
    """V-trace value targets using the product of both players' clipped ratios.

    The scan runs over axis 0 of every input in reverse (time-major). Each step
    multiplies the TD error by ``min(rho, r1) * min(rho, r2)`` and the
    propagated correction by ``lambda_ * min(c, r1) * min(c, r2) * gamma``.
    ``r1`` / ``r2`` are the two players' importance ratios. Steps where
    ``valid`` is false reset the trace and emit a zero target.

    Unlike the siblings built on `_policy_ratio`, the ratios here are not
    masked by ``valid``: a zero behaviour probability gives inf/NaN. This is
    hidden in the forward output by the final ``tree_where``.
    NOTE(review): confirm no gradient is taken through these targets.
    """
    # pi(a) of the taken action under `transformed_policy`. actions_oh gets an
    # extra axis at -3, presumably because transformed_policy has one more
    # leading (candidate) axis than sampling_policy — TODO confirm shapes.
    pi_action_prob = jnp.sum(transformed_policy * jnp.expand_dims(actions_oh, -3), axis=-1)
    mu_action_prob = jnp.sum(sampling_policy * actions_oh, axis=-1)
    importance_sampling = pi_action_prob / jnp.expand_dims(mu_action_prob, -2)
    # Player-0 ratios keep a trailing singleton and player-1 ratios get one
    # at -2, so `p1_is * p2_is` broadcasts to a grid over (p1 entry, p2 entry).
    p1_is = importance_sampling[..., 0, None]
    p2_is = jnp.expand_dims(importance_sampling[..., 1], -2)
    @chex.dataclass(frozen=True)
    class StateVTraceCarry:
        """The carry of the v-trace scan loop."""
        next_state_value: chex.Array  # v at the next timestep
        next_state_delta_v: chex.Array  # propagated correction from the next timestep
    init_carry = StateVTraceCarry(
        next_state_value=jnp.zeros_like(v[-1]),
        next_state_delta_v=jnp.zeros_like(v[-1])
    )
    def _state_v_trace(carry: StateVTraceCarry, x) -> tuple[StateVTraceCarry, Any]:
        (p1_is, p2_is, v, reward, valid) = x
        delta_v = jnp.minimum(rho, p1_is) * jnp.minimum(rho, p2_is) * (reward + gamma * carry.next_state_value - v)
        carry_delta_v = delta_v + lambda_ * jnp.minimum(c, p1_is) * jnp.minimum(c, p2_is) * gamma * carry.next_state_delta_v
        v_target = v + carry_delta_v
        reset_carry = init_carry
        next_carry = StateVTraceCarry(
            next_state_value=v,
            next_state_delta_v=carry_delta_v
        )
        # Invalid steps reset the trace and emit zeros.
        return tree_where(valid, (next_carry, v_target), (reset_carry, jnp.zeros_like(v_target)))
    # reward / valid get two trailing singleton axes to broadcast over the
    # (p1, p2) grid of ratios.
    _, v_target = lax.scan(
        f=_state_v_trace,
        init=init_carry,
        xs=(p1_is, p2_is, v, jnp.expand_dims(reward, (-1, -2)), jnp.expand_dims(valid, (-1, -2))),
        reverse=True
    )
    return v_target
def expected_v_trace(
    v: chex.Array,
    valid: chex.Array,
    sampling_policy: chex.Array,
    network_policy: chex.Array,
    regularization_term: chex.Array,
    action_oh: chex.Array,
    reward: chex.Array,
    lambda_: float = 1.0,
    c: float = 1.0,
    rho: float = 1.0,
    eta: float = 0.2,
    gamma: float = 1.0
):
    """V-trace targets whose clipped ratio is the product over the player axis.

    The reward is shaped by the difference of the two players' regularisation
    terms ``eta * E_pi[regularization_term]``. The scan runs over axis 0 in
    reverse. At each step, ``prod(min(rho, ratio), axis=-2)`` scales the TD
    error and ``prod(min(c, ratio), axis=-2)`` scales the propagated
    correction. Steps where ``valid`` is false reset the trace and emit zeros.
    """
    player_entropy = eta * jnp.sum(network_policy * regularization_term, axis=-1)
    shaped_reward = jnp.expand_dims(
        reward + (player_entropy[..., 1] - player_entropy[..., 0]), -1)
    # pi(a)/mu(a) of the taken action (1 on invalid steps).
    ratios = _policy_ratio(network_policy, sampling_policy, action_oh, valid[..., jnp.newaxis])

    @chex.dataclass(frozen=True)
    class ExpectedVTraceCarry:
        next_value: chex.Array  # v at the next timestep
        delta_v: chex.Array  # propagated correction from the next timestep

    zero_carry = ExpectedVTraceCarry(
        next_value=jnp.zeros_like(v[-1]),
        delta_v=jnp.zeros_like(v[-1])
    )

    def _step(carry: ExpectedVTraceCarry, inputs) -> tuple[ExpectedVTraceCarry, Any]:
        ratio_t, v_t, reward_t, valid_t = inputs
        joint_rho = jnp.prod(jnp.minimum(rho, ratio_t), -2)
        joint_c = jnp.prod(jnp.minimum(c, ratio_t), -2)
        td_error = reward_t + gamma * carry.next_value - v_t
        propagated = joint_rho * td_error + lambda_ * joint_c * gamma * carry.delta_v
        target = v_t + propagated
        advanced = ExpectedVTraceCarry(next_value=v_t, delta_v=propagated)
        return tree_where(valid_t, (advanced, target), (zero_carry, jnp.zeros_like(target)))

    _, v_target = lax.scan(_step, zero_carry, (ratios, v, shaped_reward, valid), reverse=True)
    return v_target