-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlamis_train.py
More file actions
1920 lines (1483 loc) · 85.5 KB
/
lamis_train.py
File metadata and controls
1920 lines (1483 loc) · 85.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from games.jax_game import JaxGame, GameState
from lamis_networks import MAVSNetwork, SimilarityNetwork, LegalActionsNetwork, PublicStateEncoder, InfosetEncoder, PublicStateDynamicsNetwork, DynamicsNetwork, TransformationNetwork, ExpectedNetwork, RNaDNetwork, PublicStateDecoder
from train_utils import EntropySchedule, optax_optimizer, masked_l2_loss_with_normalization, _policy_ratio, neurd_loss, transform_trajectory_to_last_dimension, normalize_direction_with_mask, _compute_soft_kmeans_loss_with_single, state_v_trace, expected_v_trace, v_trace, compute_soft_kmeans_transformations
from typing import Sequence, Callable
import jax
import jax.numpy as jnp
import jax.lax as lax
import chex
import optax
from enum import Enum
import numpy as np
import pyspiel
import functools
# Taken from RNaD original
# Params: a chex ArrayTree, i.e. a nested container of arrays holding network parameters.
Params = chex.ArrayTree
# Optimizer: a callable taking two parameter trees and returning a parameter tree.
# NOTE(review): presumably (params, grads) -> updated params; confirm against
# optax_optimizer in train_utils.
Optimizer = Callable[[Params, Params], Params]
@chex.dataclass(frozen=True)
class TimeStep():
    """Immutable record of the data stored for one time step.

    Every field defaults to an empty tuple. The trailing comments give the
    intended trailing dimensions; ``...`` stands for any leading dimensions.
    """
    valid: chex.Array = () # [..., 1]
    public_state: chex.Array = () # [..., PS]
    obs: chex.Array = () # [..., Player, O]
    legal: chex.Array = () # [..., Player, A]
    action: chex.Array = () # [..., Player, A]
    policy: chex.Array = () # [..., Player, A]
    reward: chex.Array = () # [..., 1] Reward after playing an action
@chex.dataclass
class Optimizers:
    """Container for the optimizers created in ``LAMISTrain.init``.

    Fields typed ``Sequence[Optimizer]`` are filled there with a
    ``(player 1, player 2)`` pair; all fields default to an empty tuple.
    """
    rnad_optimizer: Optimizer = ()
    rnad_optimizer_target: Optimizer = ()
    expected_optimizer: Optimizer = ()
    expected_optimizer_target: Optimizer = ()
    mvs_optimizer: Optimizer = ()
    mvs_optimizer_target: Optimizer = ()
    # NOTE: the field name is misspelled ("opitimizer"); it is kept as-is because
    # LAMISTrain.init constructs this dataclass with exactly this keyword.
    transformation_opitimizer: Sequence[Optimizer] = ()
    abstraction_optimizer: Sequence[Optimizer] = ()
    ps_decoder_optimizer: Sequence[Optimizer] = ()
    iset_encoder_optimizer: Sequence[Optimizer] = ()
    similarity_optimizer: Sequence[Optimizer] = ()
    legal_actions_optimizer: Sequence[Optimizer] = ()
    dynamics_optimizer: Optimizer = ()
    # The q_critic optimizers are not assigned in LAMISTrain.init and keep their defaults there.
    q_critic_optimizer: Optimizer = ()
    q_critic_optimizer_target: Optimizer = ()
@chex.dataclass
class NetworkParameters:
    """Container for the network parameters created in ``LAMISTrain.init``.

    Fields typed ``Sequence[Params]`` are filled there with a
    ``(player 1, player 2)`` pair; all fields default to an empty tuple.
    """
    rnad_params: Params = ()
    rnad_params_target: Params = ()
    # NOTE(review): presumably snapshots of earlier RNaD parameters used for
    # regularization — confirm in the RNaD loss / update code.
    rnad_params_prev: Params = ()
    rnad_params_prev_: Params = ()
    expected_params: Params = ()
    expected_params_target: Params = ()
    mvs_params: Params = ()
    mvs_params_target: Params = ()
    transformation_params: Sequence[Params] = ()
    abstraction_params: Sequence[Params] = ()
    ps_decoder_params: Sequence[Params] = ()
    iset_encoder_params: Sequence[Params] = ()
    similarity_params: Sequence[Params] = ()
    legal_actions_params: Sequence[Params] = ()
    dynamics_params: Params = ()
    # The q_critic parameters are not assigned in LAMISTrain.init and keep their defaults there.
    q_critic_params: Params = ()
    q_critic_params_target: Params = ()
def similarity_policy(pi: chex.Array, scale: float = 2):
    """Center ``pi`` around zero and stretch it.

    Subtracts 0.5 from every entry and multiplies by ``scale``; with the
    default ``scale`` an input in [0, 1] maps to [-1, 1].
    """
    centered = pi - 0.5
    return centered * scale
def similarity_value(v: chex.Array, scale: float = 1):
    """Return ``v`` multiplied by ``scale`` (unchanged for the default scale of 1)."""
    scaled = v * scale
    return scaled
def similarity_legal(legal: chex.Array, scale: float = 2):
    """Center a legal-action mask around zero and stretch it.

    Subtracts 0.5 from every entry and multiplies by ``scale``; with the
    default ``scale`` a 0/1 mask maps to -1/+1.
    """
    centered = legal - 0.5
    return centered * scale
def similarity_action_history(action: chex.Array, scale: float = 1):
    """Build, for every leading index ``t``, a vector of the actions before ``t``.

    The first axis of ``action`` is treated as the step axis. For step ``t``
    the result concatenates ``(action[s] - 0.5) * scale`` for all ``s < t``
    and zeros for ``s >= t``, where ``s`` ranges over all steps but the last.
    The three broadcast axes added to the mask mean ``action`` must have
    exactly four dimensions.

    Returns:
        Array shaped ``action.shape[:-1] + ((steps - 1) * action.shape[-1],)``.
    """
    steps = action.shape[0]
    # Strictly lower-triangular mask: row t has ones in the columns s < t.
    history_mask = jnp.tri(steps, steps - 1, k=-1)
    centered = (action[None, :-1, ...] - 0.5) * scale
    masked = history_mask[..., None, None, None] * centered
    # Move the "earlier step" axis next to the action axis so both flatten together.
    stacked = jnp.moveaxis(masked, 1, -2)
    return stacked.reshape(*action.shape[:-1], -1)
def similarity_iset(iset: chex.Array, scale: float = 2):
    """Center an information-set vector around zero and stretch it.

    Subtracts 0.5 from every entry and multiplies by ``scale``.
    """
    centered = iset - 0.5
    return centered * scale
class SimilarityMetric(str, Enum):
    """Selects which quantities make up the similarity network's output.

    ``LAMISTrain.similarity_output_size`` maps each member to an output width;
    the widths noted below are taken from that method (A = number of actions,
    H = A * (trajectory_max - 1), I = information-state tensor size).
    """
    POLICY = "policy"  # width A
    VALUE = "value"  # width 1
    POLICY_VALUE = "policy_value"  # width A + 1
    LEGAL_ACTIONS = "legal_actions"  # width A
    LEGAL_POLICY = "legal_policy"  # width 2A
    LEGAL_POLICY_VALUE = "legal_policy_value"  # width 2A + 1
    ACTION_HISTORY = "action_history"  # width H
    ACTION_HISTORY_POLICY = "action_history_policy"  # width H + A
    ACTION_HISTORY_LEGAL = "action_history_legal"  # width H + A
    ACTION_HISTORY_LEGAL_POLICY = "action_history_legal_policy"  # width H + 2A
    ISET_VECTOR = "iset_vector"  # width I
    ISET_POLICY = "iset_policy"  # width A + I
class DynamicsType(str, Enum):
    """Selects which dynamics network ``LAMISTrain.init`` builds."""
    # Builds a DynamicsNetwork.
    ISET = "iset"
    # Builds a PublicStateDynamicsNetwork; LAMISTrain.init asserts use_abstraction for this mode.
    PUBLIC_STATE = "public_state"
@chex.dataclass(frozen=True)
class LAMISTrainConfig:
    """Immutable hyper-parameter bundle read by ``LAMISTrain``.

    The notes below state only what the visible trainer code does with a
    value; fields without a note are consumed outside the visible code.
    """
    batch_size: int = 32
    # similarity_output_size sizes the action-history vector as actions * (trajectory_max - 1).
    trajectory_max: int = 6
    sampling_epsilon: float = 0.0
    # NOTE(review): the train_* flags presumably switch individual training
    # phases on/off — confirm in the training loop.
    train_rnad: bool = True
    train_transformations: bool = True
    train_mvs: bool = True
    train_abstraction: bool = True
    train_dynamics: bool = True
    train_legal_actions: bool = True
    # When True, LAMISTrain.init uses abstraction_size as the observation size and
    # selects the "abstracted_*" loss functions; otherwise the game's
    # information-state size and the "non_abstracted_*" losses are used.
    use_abstraction: bool = False
    # Passed to PublicStateEncoder, InfosetEncoder and PublicStateDynamicsNetwork.
    abstraction_amount: int = 10
    # Passed to PublicStateEncoder; also the input width used to initialise the
    # public-state decoder and the similarity network.
    abstraction_size: int = 32
    # Determines the similarity network's output width (see similarity_output_size).
    similarity_metric: SimilarityMetric = SimilarityMetric.POLICY_VALUE
    similarity_noise: float = 0.02
    abstraction_soft_k_means_temperature: float = 1.0
    abstraction_soft_k_means_closeness_assignment: float = 0.5
    abstraction_soft_k_means_repulsive_force: float = 3.0
    abstraction_hard_k_means_closeness: float = 0.2
    transformation_soft_k_means_temperature: float = 1.0
    transformation_soft_k_means_closeness_assignment: float = 0.5
    transformation_soft_k_means_repulsive_force: float = 3.0
    # Chooses between DynamicsNetwork (ISET) and PublicStateDynamicsNetwork (PUBLIC_STATE).
    dynamics_type: DynamicsType = DynamicsType.PUBLIC_STATE
    # Each *_hidden_size is the first constructor argument of the matching network in LAMISTrain.init.
    ps_encoder_hidden_size: int = 256
    ps_decoder_hidden_size: int = 256
    iset_hidden_size: int = 256
    dynamics_hidden_size: int = 256
    similarity_hidden_size: int = 256
    mvs_hidden_size: int = 256
    legal_actions_hidden_size: int = 256
    transformation_hidden_size: int = 256
    # Used for both RNaDNetwork and ExpectedNetwork.
    rnad_hidden_size: int = 256
    # Passed to TransformationNetwork; MAVSNetwork is built with transformations + 1.
    transformations: int = 10
    # LAMISTrain.__init__ asserts this is truthy.
    matrix_valued_states: bool = True
    c_iset_vtrace: float = 1.0
    rho_iset_vtrace: float = np.inf
    c_state_vtrace: float = 1.0
    rho_state_vtrace: float = np.inf
    eta_regularization: float = 0.2
    # Forwarded to EntropySchedule as `repeats` and `sizes` respectively.
    entropy_schedule_repeats: Sequence[int] = (1,)
    entropy_schedule_size: Sequence[int] = (1000,)
    # Learning rate of every adam/adamw optimizer built in LAMISTrain.init.
    learning_rate: float = 3e-4
    # Learning rate of the optax.sgd optimizers attached to the *_target parameters.
    target_network_update: float = 1e-3
    # Seed for jax.random.key in LAMISTrain.init.
    seed: int = 42
# This contains the RNaD implementation. Note that this implementation is specific to two-player zero-sum games, unlike the open_spiel RNaD, which can be used for general-sum multiplayer games.
class LAMISTrain():
def __init__(self, game, config: LAMISTrainConfig) -> None:
    """Store ``game`` and ``config`` and build all trainer state via ``init``.

    Raises:
        AssertionError: if ``config.matrix_valued_states`` is falsy.
    """
    assert config.matrix_valued_states, "Multi-valued states are not implemented."
    self.game = game
    self.config = config
    uses_jax_game = isinstance(game, JaxGame)
    if uses_jax_game:
        print("Warning: you use Jax game, so you need to use jax_step method")
    self.init()
def init(self):
    """Create all networks, differentiable losses, parameters and optimizers.

    Reads sizes from ``self.game`` and ``self.config``, seeds ``self.rng_key``
    from ``config.seed``, and stores the results in ``self.network_parameters``
    and ``self.optimizers``. ``self.learner_steps`` is reset to 0.

    Fix relative to the previous version: player 2 now gets its own similarity
    parameters (``p2_similarity_params``). Previously player 1's parameters
    were stored in both slots, so the separately initialised player-2
    parameters were unused even though the player-2 similarity optimizer was
    built from them.

    Raises:
        AssertionError: if ``dynamics_type`` is PUBLIC_STATE while
            ``use_abstraction`` is disabled.
    """
    cfg = self.config
    self.actions = self.game.num_distinct_actions()
    if cfg.use_abstraction:
        self.obs = cfg.abstraction_size
    else:
        self.obs = self.game.information_state_tensor_shape()

    self.rng_key = jax.random.key(cfg.seed)
    self.example_state = self.new_initial_state()
    self.example_timestep = self.default_timestep()
    self.example_obs = np.ones((self.obs))

    self._entropy_schedule = EntropySchedule(
        sizes=cfg.entropy_schedule_size,
        repeats=cfg.entropy_schedule_repeats)

    # --- Networks -----------------------------------------------------------
    self.expected_network = ExpectedNetwork(cfg.rnad_hidden_size)
    self.rnad_network = RNaDNetwork(cfg.rnad_hidden_size, self.actions)
    self.abstraction_network = PublicStateEncoder(cfg.ps_encoder_hidden_size, cfg.abstraction_size, cfg.abstraction_amount)
    self.ps_decoder = PublicStateDecoder(cfg.ps_decoder_hidden_size, self.game.public_state_tensor_shape())
    self.iset_encoder = InfosetEncoder(cfg.iset_hidden_size, cfg.abstraction_amount)
    self.similarity_network = SimilarityNetwork(cfg.similarity_hidden_size, self.similarity_output_size())
    self.legal_actions_network = LegalActionsNetwork(cfg.legal_actions_hidden_size, self.actions)
    if cfg.dynamics_type == DynamicsType.ISET:
        self.dynamics_network = DynamicsNetwork(cfg.dynamics_hidden_size, self.obs)
    elif cfg.dynamics_type == DynamicsType.PUBLIC_STATE:
        assert cfg.use_abstraction == True, "Dynamics for Public state work only with abstrations."
        self.dynamics_network = PublicStateDynamicsNetwork(cfg.dynamics_hidden_size, self.game.public_state_tensor_shape(), cfg.abstraction_amount)
    self.transformation_network = TransformationNetwork(cfg.transformation_hidden_size, cfg.transformations, self.actions)
    self.mvs_network = MAVSNetwork(cfg.mvs_hidden_size, cfg.transformations + 1)

    # --- Loss functions (value + gradient) ----------------------------------
    self._rnad_loss = jax.value_and_grad(self.rnad_loss, has_aux=False)  # Deprecate this?
    self._abstraction_loss = jax.value_and_grad(self.abstraction_loss, argnums=[0, 1, 2, 3], has_aux=False)
    self._expected_loss = jax.value_and_grad(self.expected_loss, has_aux=False)
    self._rnad_with_expected_loss = jax.value_and_grad(self.rnad_with_expected_loss, has_aux=False)
    if cfg.use_abstraction:
        if cfg.dynamics_type == DynamicsType.ISET:
            self._dynamics_loss = jax.value_and_grad(self.abstracted_dynamics_loss, has_aux=False)
        elif cfg.dynamics_type == DynamicsType.PUBLIC_STATE:
            self._dynamics_loss = jax.value_and_grad(self.abstracted_ps_dynamics_loss, has_aux=False)
        self._transformation_loss = jax.value_and_grad(self.abstracted_transformation_loss, has_aux=False)
        self._mvs_loss = jax.value_and_grad(self.abstracted_mvs_loss, has_aux=False)
        self._legal_actions_loss = jax.value_and_grad(self.abstracted_legal_actions_loss, has_aux=False)
    else:
        self._dynamics_loss = jax.value_and_grad(self.non_abstracted_dynamics_loss, has_aux=False)
        self._transformation_loss = jax.value_and_grad(self.non_abstracted_transformation_loss, has_aux=False)
        self._mvs_loss = jax.value_and_grad(self.non_abstracted_mvs_loss, has_aux=False)
        self._legal_actions_loss = jax.value_and_grad(self.non_abstracted_legal_actions_loss, has_aux=False)

    # --- Parameters ---------------------------------------------------------
    temp_keys = self.get_next_rng_keys(16)
    example_ts = self.example_timestep
    example_abstraction = np.ones((1, cfg.abstraction_size))

    # Online, target and both "previous" RNaD parameter sets share one key, so
    # they start out identical.
    params = self.rnad_network.init(temp_keys[0], example_ts.obs, example_ts.legal)
    params_target = self.rnad_network.init(temp_keys[0], example_ts.obs, example_ts.legal)
    params_prev = self.rnad_network.init(temp_keys[0], example_ts.obs, example_ts.legal)
    params_prev_ = self.rnad_network.init(temp_keys[0], example_ts.obs, example_ts.legal)

    # TODO: Different init?
    p1_abstraction_params = self.abstraction_network.init(temp_keys[1], example_ts.public_state)
    p2_abstraction_params = self.abstraction_network.init(temp_keys[2], example_ts.public_state)
    # TODO: Do we want 2 different networks for iset encoder and similarity?
    p1_iset_encoder_params = self.iset_encoder.init(temp_keys[3], example_ts.obs)
    p2_iset_encoder_params = self.iset_encoder.init(temp_keys[4], example_ts.obs)
    p1_ps_decoder_params = self.ps_decoder.init(temp_keys[5], example_abstraction)
    p2_ps_decoder_params = self.ps_decoder.init(temp_keys[6], example_abstraction)
    # Similarity always uses abstraction
    p1_similarity_params = self.similarity_network.init(temp_keys[7], example_abstraction)
    p2_similarity_params = self.similarity_network.init(temp_keys[8], example_abstraction)
    p1_legal_actions_params = self.legal_actions_network.init(temp_keys[9], self.example_obs)
    p2_legal_actions_params = self.legal_actions_network.init(temp_keys[10], self.example_obs)
    dynamics_params = self.dynamics_network.init(temp_keys[11], self.example_obs, self.example_obs, example_ts.action, example_ts.action)
    # Online and target networks share a key so they start out identical.
    mvs_params = self.mvs_network.init(temp_keys[12], self.example_obs, self.example_obs)
    mvs_params_target = self.mvs_network.init(temp_keys[12], self.example_obs, self.example_obs)
    p1_transformation_params = self.transformation_network.init(temp_keys[13], self.example_obs)
    p2_transformation_params = self.transformation_network.init(temp_keys[14], self.example_obs)
    expected_params = self.expected_network.init(temp_keys[15], example_ts.obs, example_ts.obs)
    expected_params_target = self.expected_network.init(temp_keys[15], example_ts.obs, example_ts.obs)

    # --- Optimizers ---------------------------------------------------------
    learning_rate = cfg.learning_rate

    def adam_clipped(net_params):
        # Adam followed by element-wise clipping of the update to [-100, 100].
        return optax_optimizer(net_params, optax.chain(optax.adam(learning_rate), optax.clip(100)))

    def adamw_clipped(net_params, weight_decay):
        # AdamW followed by element-wise clipping of the update to [-1, 1].
        return optax_optimizer(net_params, optax.chain(optax.adamw(learning_rate, weight_decay=weight_decay), optax.clip(1)))

    def target_sgd(net_params):
        # Plain SGD whose step size is the target-network update rate.
        return optax_optimizer(net_params, optax.sgd(cfg.target_network_update))

    # The RNaD optimizer is the only one that disables Adam's first moment (b1=0).
    optimizer = optax_optimizer(params, optax.chain(optax.adam(learning_rate, b1=0.0), optax.clip(100)))
    optimizer_target = target_sgd(params_target)

    self.optimizers = Optimizers(
        rnad_optimizer=optimizer,
        rnad_optimizer_target=optimizer_target,
        expected_optimizer=adam_clipped(expected_params),
        expected_optimizer_target=target_sgd(expected_params_target),
        mvs_optimizer=adam_clipped(mvs_params),
        mvs_optimizer_target=target_sgd(mvs_params_target),
        transformation_opitimizer=(
            adamw_clipped(p1_transformation_params, 0.0),
            adamw_clipped(p2_transformation_params, 0.0),
        ),
        abstraction_optimizer=(adam_clipped(p1_abstraction_params), adam_clipped(p2_abstraction_params)),
        ps_decoder_optimizer=(adam_clipped(p1_ps_decoder_params), adam_clipped(p2_ps_decoder_params)),
        iset_encoder_optimizer=(adam_clipped(p1_iset_encoder_params), adam_clipped(p2_iset_encoder_params)),
        similarity_optimizer=(
            adamw_clipped(p1_similarity_params, 1e-5),
            adamw_clipped(p2_similarity_params, 1e-5),
        ),
        legal_actions_optimizer=(adam_clipped(p1_legal_actions_params), adam_clipped(p2_legal_actions_params)),
        dynamics_optimizer=adam_clipped(dynamics_params),
    )

    self.network_parameters = NetworkParameters(
        rnad_params=params,
        rnad_params_target=params_target,
        rnad_params_prev=params_prev,
        rnad_params_prev_=params_prev_,
        expected_params=expected_params,
        expected_params_target=expected_params_target,
        mvs_params=mvs_params,
        mvs_params_target=mvs_params_target,
        transformation_params=(p1_transformation_params, p2_transformation_params),
        abstraction_params=(p1_abstraction_params, p2_abstraction_params),
        ps_decoder_params=(p1_ps_decoder_params, p2_ps_decoder_params),
        iset_encoder_params=(p1_iset_encoder_params, p2_iset_encoder_params),
        # Each player keeps its own similarity parameters (was p1 twice).
        similarity_params=(p1_similarity_params, p2_similarity_params),
        legal_actions_params=(p1_legal_actions_params, p2_legal_actions_params),
        dynamics_params=dynamics_params,
    )
    self.learner_steps = 0
def similarity_output_size(self):
    """Return the output width of the similarity network for the configured metric.

    The width is assembled from the number of actions, the action-history
    length ``actions * (trajectory_max - 1)`` and, for the ISET metrics, the
    game's information-state tensor size.

    Raises:
        AssertionError: if ``config.similarity_metric`` is not a known metric.
    """
    metric = self.config.similarity_metric
    n_actions = self.actions
    history = n_actions * (self.config.trajectory_max - 1)

    if metric in (SimilarityMetric.POLICY, SimilarityMetric.LEGAL_ACTIONS):
        return n_actions
    if metric == SimilarityMetric.VALUE:
        return 1
    if metric == SimilarityMetric.POLICY_VALUE:
        return n_actions + 1
    if metric == SimilarityMetric.LEGAL_POLICY:
        return 2 * n_actions
    if metric == SimilarityMetric.LEGAL_POLICY_VALUE:
        return 2 * n_actions + 1
    if metric == SimilarityMetric.ACTION_HISTORY:
        return history
    if metric in (SimilarityMetric.ACTION_HISTORY_POLICY, SimilarityMetric.ACTION_HISTORY_LEGAL):
        return history + n_actions
    if metric == SimilarityMetric.ACTION_HISTORY_LEGAL_POLICY:
        return history + 2 * n_actions
    if metric == SimilarityMetric.ISET_VECTOR:
        return self.game.information_state_tensor_shape()
    if metric == SimilarityMetric.ISET_POLICY:
        return n_actions + self.game.information_state_tensor_shape()
    assert False, "Unknown similarity metric"
def default_timestep(self):
    """Return a placeholder ``TimeStep`` sized for ``self.game``.

    Observations and public state are zeros, legal/action/policy are ones and
    ``valid``/``reward`` are ``[0]``. Observations and legal masks carry a
    leading player axis of size 2.
    """
    info_state_size = self.game.information_state_tensor_shape()
    public_state_size = self.game.public_state_tensor_shape()
    n_actions = self.actions
    return TimeStep(
        valid=np.array([0], dtype=np.float32),
        public_state=np.zeros(public_state_size, dtype=np.float32),
        obs=np.zeros((2, info_state_size), dtype=np.float32),
        legal=np.ones((2, n_actions), dtype=np.int8),
        action=np.ones(n_actions, dtype=np.float32),
        policy=np.ones(n_actions, dtype=np.float32),
        reward=np.array([0], dtype=np.float32),
    )
def new_initial_state(self):
    """Return a fresh initial state of ``self.game``.

    A ``JaxGame`` is given a new RNG key drawn from the trainer; any other
    game is asked for its initial state without arguments.
    """
    if not isinstance(self.game, JaxGame):
        return self.game.new_initial_state()
    return self.game.new_initial_state(self.get_next_rng_key())
@functools.partial(jax.jit, static_argnums=(0,))
def _jit_get_rnad_network(self, params, obs, legal) -> chex.Array:
return self.rnad_network.apply(params, obs, legal)
@functools.partial(jax.jit, static_argnums=(0,))
def _jit_get_policy(self, params, obs, legal) -> chex.Array:
return self._jit_get_rnad_network(params, obs, legal)[0]
# TODO: Be careful, this sometimes produces an action that is illegal
@functools.partial(jax.jit, static_argnums=(0, ))
def _jit_sample_action(self, key, pi: chex.Array):
def choice_wrapper(key, pi):
return jax.random.choice(key, self.actions, p=pi)
action = jax.vmap(choice_wrapper, in_axes=(0, 0), out_axes=0)(key, pi)
action_oh = jax.nn.one_hot(action, self.actions)
return action, action_oh
@functools.partial(jax.jit, static_argnums=(0,))
def _jit_get_policy_and_action(self, params, key, obs, legal) -> chex.Array:
pi = self._jit_get_policy(params, obs, legal)
action, action_oh = self._jit_sample_action(key, pi)
return pi, action, action_oh
@functools.partial(jax.jit, static_argnums=(0,))
def _jit_get_batch_policy(self, params, keys, obs, legal) -> chex.Array:
return jax.vmap(self._jit_get_policy_and_action, in_axes=(None, 1, 1, 1), out_axes=1)(params, keys, obs, legal)
@functools.partial(jax.jit, static_argnums=(0,))
def _jit_get_legal_actions(self, legal_actions_params, obs) -> chex.Array:
return self.legal_actions_network.apply(legal_actions_params, obs)
@functools.partial(jax.jit, static_argnums=(0,))
def _jit_get_next_state(self, params, p1_iset, p2_iset, p1_action, p2_action):
p1_action = jax.nn.one_hot(p1_action, self.actions)
p2_action = jax.nn.one_hot(p2_action, self.actions)
return self.dynamics_network.apply(params, p1_iset, p2_iset, p1_action, p2_action)
@functools.partial(jax.jit, static_argnums=(0,))
def _jit_get_next_state_ps(self, dynamics_params, p1_abstraction_params, p2_abstraction_params, p1_iset, p2_iset, p1_action, p2_action):
next_ps, next_p1_dist, next_p2_dist, reward, terminal = self._jit_get_next_state(dynamics_params, p1_iset, p2_iset, p1_action, p2_action)
next_ps = jnp.where(next_ps > 0.5, 1, 0)
next_p1_isets = self._jit_get_all_abstractions(p1_abstraction_params, next_ps)
next_p2_isets = self._jit_get_all_abstractions(p2_abstraction_params, next_ps)
next_p1_iset = jnp.argmax(next_p1_dist, axis=-1, keepdims=True)
next_p2_iset = jnp.argmax(next_p2_dist, axis=-1, keepdims=True)
return jnp.squeeze(jnp.take_along_axis(next_p1_isets, next_p1_iset[..., jnp.newaxis], axis=-2), -2), jnp.squeeze(jnp.take_along_axis(next_p2_isets, next_p2_iset[..., jnp.newaxis], axis=-2), -2), reward, terminal
@functools.partial(jax.jit, static_argnums=(0,))
def _jit_get_all_abstractions(self, abstraction_params, public_state):
return self.abstraction_network.apply(abstraction_params, public_state)
@functools.partial(jax.jit, static_argnums=(0,))
def _jit_get_iset_probabilities(self, iset_encoder_params, obs):
return self.iset_encoder.apply(iset_encoder_params, obs)
@functools.partial(jax.jit, static_argnums=(0,))
def _jit_get_similarity(self, similarity_params, obs):
return self.similarity_network.apply(similarity_params, obs)
@functools.partial(jax.jit, static_argnums=(0,))
def _jit_get_abstraction(self,abstraction_params, iset_params, public_state, obs):
abstraction = self.abstraction_network.apply(abstraction_params, public_state)
iset = self.iset_encoder.apply(iset_params, obs)
picked_iset = jnp.argmax(iset, axis=-1, keepdims=True)
return jnp.squeeze(jnp.take_along_axis(abstraction, picked_iset[..., jnp.newaxis], axis=-2), -2)
@functools.partial(jax.jit, static_argnums=(0,))
def _jit_get_full_abstraction(self, abstraction_params, public_state):
return self.abstraction_network.apply(abstraction_params, public_state)
@functools.partial(jax.jit, static_argnums=(0,))
def _jit_get_abstraction_with_iset_id(self,abstraction_params, iset_params, public_state, obs):
abstraction = self.abstraction_network.apply(abstraction_params, public_state)
iset = self.iset_encoder.apply(iset_params, obs)
picked_iset = jnp.argmax(iset, axis=-1, keepdims=True)
return picked_iset, jnp.squeeze(jnp.take_along_axis(abstraction, picked_iset[..., jnp.newaxis], axis=-2), -2)
@functools.partial(jax.jit, static_argnums=(0, ))
def _jit_get_mvs(self, mvs_params, p1_iset, p2_iset):
return self.mvs_network.apply(mvs_params, p1_iset, p2_iset)
@functools.partial(jax.jit, static_argnums=(0, ))
def _jit_get_decoded_public_state(self, ps_decoder_params, obs):
return self.ps_decoder.apply(ps_decoder_params, obs)
# The observation is already restricted to the given player pl
def get_abstraction(self, public_state, obs, pl):
if not self.config.use_abstraction:
return obs
return self._jit_get_abstraction(self.network_parameters.abstraction_params[pl], self.network_parameters.iset_encoder_params[pl], public_state, obs)
def get_both_abstraction(self, public_state, p1_iset, p2_iset):
if not self.config.use_abstraction:
return p1_iset, p2_iset
p1_abstraction_iset = self._jit_get_abstraction(self.network_parameters.abstraction_params[0], self.network_parameters.iset_encoder_params[0], public_state, p1_iset)
p2_abstraction_iset = self._jit_get_abstraction(self.network_parameters.abstraction_params[1], self.network_parameters.iset_encoder_params[1], public_state, p2_iset)
return p1_abstraction_iset, p2_abstraction_iset
def get_both_full_abstraction(self, public_state):
if not self.config.use_abstraction:
return jnp.ones((2, 1))
p1_abstraction_distribution = self._jit_get_full_abstraction(self.network_parameters.abstraction_params[0], public_state)
p2_abstraction_distribution = self._jit_get_full_abstraction(self.network_parameters.abstraction_params[1], public_state)
return p1_abstraction_distribution, p2_abstraction_distribution
def get_both_iset_probabilities(self, p1_iset, p2_iset):
p1_abstraction_distribution = self._jit_get_iset_probabilities(self.network_parameters.iset_encoder_params[0], p1_iset)
p2_abstraction_distribution = self._jit_get_iset_probabilities(self.network_parameters.iset_encoder_params[1], p2_iset)
p1_abstraction_distribution = jax.nn.softmax(p1_abstraction_distribution, axis=-1)
p2_abstraction_distribution = jax.nn.softmax(p2_abstraction_distribution, axis=-1)
return p1_abstraction_distribution, p2_abstraction_distribution
def get_decoded_public_state(self, obs, pl):
return self._jit_get_decoded_public_state(self.network_parameters.ps_decoder_params[pl], obs)
def get_next_state_from_abstraction(self, p1_iset, p2_iset, p1_action, p2_action):
    """Predict the next state from (already abstracted) infosets and action indices.

    Dispatches on ``config.dynamics_type``: ISET calls the dynamics network
    directly, PUBLIC_STATE goes through ``_jit_get_next_state_ps`` with both
    players' abstraction parameters.

    Raises:
        AssertionError: if the dynamics type is neither of the two.
    """
    dynamics_type = self.config.dynamics_type
    if dynamics_type == DynamicsType.ISET:
        return self._jit_get_next_state(
            self.network_parameters.dynamics_params,
            p1_iset, p2_iset, p1_action, p2_action,
        )
    if dynamics_type == DynamicsType.PUBLIC_STATE:
        stored = self.network_parameters
        return self._jit_get_next_state_ps(
            stored.dynamics_params,
            stored.abstraction_params[0],
            stored.abstraction_params[1],
            p1_iset, p2_iset, p1_action, p2_action,
        )
    assert False, "Wrong dynamics type"
# Expects isets as defined by the original game and each action as an index of the action
def get_next_state(self, public_state, p1_iset, p2_iset, p1_action, p2_action):
if self.config.use_abstraction:
p1_iset, p2_iset = self.get_both_abstraction(public_state, p1_iset, p2_iset)
return self.get_next_state_from_abstraction(p1_iset, p2_iset, p1_action, p2_action)
def get_legal_actions(self, public_state, obs, pl):
if self.config.use_abstraction:
obs = self.get_abstraction(public_state, obs, pl)
return self._jit_get_legal_actions(self.network_parameters.legal_actions_params[pl], obs)
def get_both_legal_actions_from_abstraction(self, p1_iset, p2_iset):
p1_legal = self._jit_get_legal_actions(self.network_parameters.legal_actions_params[0], p1_iset)
p2_legal = self._jit_get_legal_actions(self.network_parameters.legal_actions_params[1], p2_iset)
return p1_legal, p2_legal
def get_both_legal_actions(self, public_state, p1_iset, p2_iset):
if self.config.use_abstraction:
p1_iset, p2_iset = self.get_both_abstraction(public_state, p1_iset, p2_iset)
p1_legal = self._jit_get_legal_actions(self.network_parameters.legal_actions_params[0], p1_iset)
p2_legal = self._jit_get_legal_actions(self.network_parameters.legal_actions_params[1], p2_iset)
return p1_legal, p2_legal
def get_mvs_from_abstraction(self, p1_iset, p2_iset):
return self._jit_get_mvs(self.network_parameters.mvs_params_target, p1_iset, p2_iset)
def get_mvs(self, public_state, p1_iset, p2_iset):
if self.config.use_abstraction:
p1_iset, p2_iset = self.get_both_abstraction(public_state, p1_iset, p2_iset)
return self._jit_get_mvs(self.network_parameters.mvs_params_target, p1_iset, p2_iset)
def get_policy(self, state: pyspiel.State, player: int):
obs = state.information_state_tensor(player)
legal = state.legal_actions_mask(player)
pi = self._jit_get_policy(self.network_parameters.rnad_params, obs, legal)
return np.array(pi, dtype=np.float32)
def get_policy_both(self, state: pyspiel.State):
obs = [state.information_state_tensor(pl) for pl in range(2)]
legal = [state.legal_actions_mask(pl) for pl in range(2)]
obs = np.array(obs, dtype=np.float32)
legal = np.array(legal, dtype=np.int8)
pi = self._jit_get_policy(self.network_parameters.rnad_params, obs, legal)
pi = np.array(pi, dtype=np.float64)
return pi[0], pi[1]
def get_policy_and_value(self, state: pyspiel.State, player: int):
obs = state.information_state_tensor(player)
legal = state.legal_actions_mask(player)
pi, v, _, _ = self._jit_get_rnad_network(self.network_parameters.rnad_params, obs, legal)
return np.array(pi, dtype=np.float64), np.array(v, dtype=np.float64)
def get_policy_and_value_both(self, obs, legal):
pi, v, _, _ = self._jit_get_rnad_network(self.network_parameters.rnad_params, obs, legal)
return pi[0], pi[1], v[0], v[1]
def get_policy_and_value_from_state_both(self, state: pyspiel.State):
    """Return policies and values of players 0 and 1 for ``state``.

    Both players' information-state tensors (float32) and legal-action masks
    (int8) are stacked and evaluated with ``_jit_get_rnad_network`` in one
    call; the policy and value outputs are converted to float64 NumPy arrays.

    Returns:
        ``(p1_policy, p2_policy, p1_value, p2_value)``.
    """
    info_tensors = [state.information_state_tensor(pl) for pl in range(2)]
    legal_masks = [state.legal_actions_mask(pl) for pl in range(2)]
    stacked_obs = np.array(info_tensors, dtype=np.float32)
    stacked_legal = np.array(legal_masks, dtype=np.int8)
    raw_policy, raw_value, _, _ = self._jit_get_rnad_network(
        self.network_parameters.rnad_params, stacked_obs, stacked_legal
    )
    policies = np.array(raw_policy, dtype=np.float64)
    values = np.array(raw_value, dtype=np.float64)
    return policies[0], policies[1], values[0], values[1]
def get_both_similarities_and_probs(self, public_state: chex.Array, p1_iset: chex.Array, p2_iset: chex.Array):
    """Compute abstractions, infoset probabilities and similarities for both players.

    For each player the abstractions come from ``_jit_get_all_abstractions``
    applied to ``public_state``, the probabilities from
    ``_jit_get_iset_probabilities`` applied to that player's infoset, and the
    similarities from ``_jit_get_similarity`` applied to that player's
    abstractions.

    Returns:
        ``(p1_abstractions, p2_abstractions, p1_probs, p2_probs,
        p1_similarities, p2_similarities)``.
    """
    net = self.network_parameters
    abstractions = [
        self._jit_get_all_abstractions(net.abstraction_params[pl], public_state)
        for pl in range(2)
    ]
    probs = [
        self._jit_get_iset_probabilities(net.iset_encoder_params[pl], iset)
        for pl, iset in enumerate((p1_iset, p2_iset))
    ]
    similarities = [
        self._jit_get_similarity(net.similarity_params[pl], abstractions[pl])
        for pl in range(2)
    ]
    return (
        abstractions[0],
        abstractions[1],
        probs[0],
        probs[1],
        similarities[0],
        similarities[1],
    )
# TODO: Improve this
# Expects obs and legal to be in shape [Batch, Player, ...]
def batch_policy_and_action(self, obs, legal):
    """Compute policies and sample one action per entry of a batched observation.

    PRNG keys are generated with the shape of the first two axes of ``obs``
    and handed to ``_jit_get_batch_policy`` together with ``obs`` and
    ``legal``.

    Returns:
        ``(pi, action, action_oh)`` as NumPy arrays: the policy (float64,
        renormalised over the last axis), the sampled actions (int32) and
        their one-hot form (float64).
    """
    rng_keys = np.array(self.get_next_rng_keys_dimensional(obs.shape[:2]))
    raw_policy, raw_action, raw_action_oh = self._jit_get_batch_policy(
        self.network_parameters.rnad_params, rng_keys, obs, legal
    )
    policy = np.array(raw_policy, dtype=np.float64)
    # TODO: Remove this renormalisation once it is no longer needed.
    policy = policy / policy.sum(axis=-1, keepdims=True)
    sampled = np.array(raw_action, dtype=np.int32)
    sampled_oh = np.array(raw_action_oh, dtype=np.float64)
    return policy, sampled, sampled_oh
def _batch_states_as_timestep(self, states: Sequence[pyspiel.State]) -> TimeStep:
    """Record one step for a batch of states and advance the live ones in place.

    Terminal states are padded with the tensors of ``self.example_state`` and
    marked with ``valid = 0``; their reward is 0 and they are not advanced.
    Every other state gets a joint action sampled via
    ``batch_policy_and_action``, which is applied with ``apply_actions``
    (mutating the state), and the reward recorded is ``state.returns()[0]``
    read after the action was applied.

    Raises:
        ValueError: If a sampled action is not among the legal actions of the
            corresponding player.
    """
    p1_obs, p2_obs, p1_legal, p2_legal, valid = [], [], [], [], []
    for state in states:
        is_live = not state.is_terminal()
        # Finished games borrow the example state's tensors as padding.
        source = state if is_live else self.example_state
        p1_obs.append(source.information_state_tensor(0))
        p2_obs.append(source.information_state_tensor(1))
        p1_legal.append(source.legal_actions_mask(0))
        p2_legal.append(source.legal_actions_mask(1))
        valid.append(1 if is_live else 0)

    # Player axis is inserted at position 1: [Batch, Player, ...].
    obs = np.stack((p1_obs, p2_obs), axis=1, dtype=np.float32)
    legal = np.stack((p1_legal, p2_legal), axis=1, dtype=np.int8)
    valid = np.array(valid, dtype=np.float32)
    # The public state is read from every state, terminal ones included.
    public_state = np.array(
        [state.public_state_tensor() for state in states], dtype=np.float32
    )

    pi, action, action_oh = self.batch_policy_and_action(obs, legal)

    rewards = []
    for idx, state in enumerate(states):
        if state.is_terminal():
            rewards.append(0)
            continue
        joint_action = action[idx]
        if not (
            joint_action[0] in state.legal_actions(0)
            and joint_action[1] in state.legal_actions(1)
        ):
            raise ValueError("Illegal action")
        state.apply_actions(joint_action)
        rewards.append(state.returns()[0])
    reward = np.array(rewards, dtype=np.float32)

    return TimeStep(
        valid=valid,
        public_state=public_state,
        obs=obs,
        legal=legal,
        action=action_oh,
        policy=pi,
        reward=reward,
    )
# No chance in the game!
def sample_trajectories(self) -> TimeStep:
    """Play ``config.batch_size`` fresh games for ``config.trajectory_max`` steps.

    Returns:
        A ``TimeStep`` whose leaves are the per-step leaves stacked along a
        new leading axis.
    """
    batch = [self.game.new_initial_state() for _ in range(self.config.batch_size)]

    # The same list object is handed to every call; the states inside it are
    # advanced in place, so consecutive calls see consecutive game steps.
    per_step = [
        self._batch_states_as_timestep(batch)
        for _ in range(self.config.trajectory_max)
    ]

    def _stack_leaves(*leaves):
        return np.stack(leaves, axis=0)

    return jax.tree.map(_stack_leaves, *per_step)
def sample_trajectory(self, params, key) -> TimeStep:
    """Sample one trajectory of ``config.trajectory_max`` steps with ``lax.scan``.

    Args:
        params: Parameters passed to ``_jit_get_policy`` at every step.
        key: PRNG key; split into a key for game initialisation and one key
            per step.

    Returns:
        The ``TimeStep`` produced by the scan, i.e. the per-step timesteps
        stacked along a leading step axis. The recorded ``policy`` is the
        epsilon-mixed sampling policy, not the raw network output.
    """
    init_key, trajectory_key, = jax.random.split(key)
    # One key per scan step.
    trajectory_key = jax.random.split(trajectory_key, self.config.trajectory_max)
    max_turns = self.config.trajectory_max
    actions = self.actions
    # Loop state threaded through the scan.
    @chex.dataclass(frozen=True)
    class SampleTrajectoryCarry:
        game_state: GameState
        terminal: bool
        legal_actions: chex.Array
    game_state, legal_actions = self.game.initialize_structures(init_key)
    init_carry = SampleTrajectoryCarry(
        game_state = game_state,
        terminal = False,
        legal_actions = legal_actions
    )
    # Samples one action index from distribution `p` and returns it with its one-hot encoding.
    @jax.jit
    def choice_wrapper(key, p):
        action = jax.random.choice(key, actions, p=p)
        action_oh = jax.nn.one_hot(action, actions)
        return action, action_oh
    # Maps over the leading axis of both the keys and the policies (one entry per player below).
    vectorized_sample_action = jax.vmap(choice_wrapper, in_axes=(0, 0), out_axes=0)
    # NOTE(review): the second element actually returned is a TimeStep, not a bare chex.Array.
    def _sample_trajectory(carry: SampleTrajectoryCarry, xs) -> tuple[SampleTrajectoryCarry, chex.Array]:
        (key, turn) = xs
        _, p1_iset, p2_iset, public_state = self.game.get_info(carry.game_state)
        obs = jnp.stack((p1_iset, p2_iset), axis=0)
        # Once the game has ended, observations are replaced by the example timestep's values.
        public_state = jnp.where(carry.terminal, self.example_timestep.public_state, public_state)
        obs = jnp.where(carry.terminal, self.example_timestep.obs, obs)
        pi = self._jit_get_policy(params, obs, carry.legal_actions)
        # Mix the network policy with a uniform policy over legal actions (weight `sampling_epsilon`).
        random_pi = carry.legal_actions / jnp.sum(carry.legal_actions, axis=-1, keepdims=True)
        pi = self.config.sampling_epsilon * random_pi + (1 - self.config.sampling_epsilon) * pi
        sample_key, action_key = jax.random.split(key)
        # For each player samples a single action
        sample_key = jax.random.split(sample_key, 2)
        action, action_oh = vectorized_sample_action(sample_key, pi)
        next_game_state, terminal, next_rewards, next_legal = self.game.apply_action(carry.game_state, action_key, turn, action)
        # 1 while the game was still running before this step, 0 afterwards.
        valid = jnp.ones_like(next_rewards) - carry.terminal
        # Terminal status is sticky: once set it stays set for the rest of the scan.
        terminal = jnp.logical_or(terminal, carry.terminal)
        #TODO: This can likely be done better, couldnt get tree_where to work
        #timestep_legal = jnp.where(valid[..., None, None], carry.legal_actions, self.example_timestep.legal)
        # Rewards of steps taken after the game ended are zeroed out.
        next_rewards = jnp.where(valid, next_rewards, 0)
        new_carry = SampleTrajectoryCarry(
            game_state = next_game_state,
            terminal = terminal,
            legal_actions=jnp.where(terminal, self.example_timestep.legal, next_legal)
        )
        # The timestep stores the legal actions and observations seen *before* the action was applied.
        timestep = TimeStep(
            valid = valid,
            public_state = public_state,
            obs = obs,
            legal = carry.legal_actions,
            action = action_oh,
            policy = pi,
            reward = next_rewards
        )
        return new_carry, timestep
    _, timestep = lax.scan(_sample_trajectory,
                           init=init_carry,
                           xs=(trajectory_key, jnp.arange(max_turns)))
    return timestep
def get_next_rng_key(self):
    """Advance ``self.rng_key`` and return one fresh PRNG key.

    The stored key is split in two: the first half replaces ``self.rng_key``,
    the second half is returned to the caller.
    """
    carried_key, fresh_key = jax.random.split(self.rng_key)
    self.rng_key = carried_key
    return fresh_key
def get_next_rng_keys(self, n):
    """Advance ``self.rng_key`` and return a list of ``n`` fresh PRNG keys.

    The stored key is split into ``n + 1`` keys: the first replaces
    ``self.rng_key`` and the remaining ``n`` are returned as a list.
    """
    split_keys = list(jax.random.split(self.rng_key, n + 1))
    self.rng_key = split_keys.pop(0)
    return split_keys
# Draws one fresh key from the generator state, then splits it into `n` keys (`n` may be an int or a shape tuple such as `obs.shape[:2]`).
def get_next_rng_keys_dimensional(self, n):
    """Return PRNG keys obtained by splitting one fresh key according to ``n``.

    A single key is drawn with ``get_next_rng_key`` (which advances
    ``self.rng_key``) and then split with ``jax.random.split(key, n)``.
    ``n`` is forwarded unchanged; ``batch_policy_and_action`` passes a shape
    tuple here.
    """
    return jax.random.split(self.get_next_rng_key(), n)
def rnad_loss(
    self,
    params: Params,
    params_target: Params,
    params_prev: Params,
    params_prev_: Params,
    timestep: TimeStep,
    alpha: float,
):
    """Compute the scalar training loss: value loss plus NeuRD policy loss.

    Args:
        params: Parameters being optimised; produce ``pi``, ``v``, ``log_pi``
            and ``logit``.
        params_target: Parameters used only for the value estimates passed to
            ``v_trace``.
        params_prev: Parameters whose log-policy is weighted by ``alpha`` in
            the regularization term.
        params_prev_: Parameters whose log-policy is weighted by
            ``1 - alpha`` in the regularization term.
        timestep: Sampled data; ``obs``, ``legal``, ``valid``, ``policy``,
            ``action`` and ``reward`` are read here.
        alpha: Interpolation weight between the two previous log-policies.

    Returns:
        ``v_loss + neurd_loss_value``; both terms are divided by twice the
        sum of ``timestep.valid`` (or by 1 when that sum is zero).
    """
    # We map over trajectory dimension and player dimension
    # (outer vmap: axis -2 of obs/legal, inner vmap: axis 0 of what remains).
    vectorized_net_apply = jax.vmap(jax.vmap(self.rnad_network.apply, in_axes=(None, 0, 0), out_axes=0), in_axes=(None, -2, -2), out_axes=-2)
    pi, v, log_pi, logit = vectorized_net_apply(params, timestep.obs, timestep.legal)
    _, v_target, _, _ = vectorized_net_apply(params_target, timestep.obs, timestep.legal)
    _, _, log_pi_prev, _ = vectorized_net_apply(params_prev, timestep.obs, timestep.legal)
    _, _, log_pi_prev_, _ = vectorized_net_apply(params_prev_, timestep.obs, timestep.legal)
    # This creates the regularization term for rewards
    regularized_term = log_pi - (alpha * log_pi_prev + (1 - alpha) * log_pi_prev_)
    # Two trailing singleton axes are appended so `valid` broadcasts against the per-player outputs.
    expanded_valid = jnp.expand_dims(timestep.valid, (-2, -1))
    v_train_target, q_value = v_trace(v_target, expanded_valid, timestep.policy, pi, regularized_term, timestep.action, timestep.reward, c=self.config.c_iset_vtrace, rho=self.config.rho_iset_vtrace, eta=self.config.eta_regularization)
    # We multiply by 2, since each player acts
    normalization = jnp.sum(timestep.valid) * 2
    # `(normalization == 0)` adds 1 to the denominator only when there are no valid steps, avoiding division by zero.
    v_loss = jnp.sum((expanded_valid * (v - lax.stop_gradient(v_train_target)) ** 2)) / (normalization + (normalization == 0))
    # Each Q is multiplied by product of importance_sampling of opponent and inverted sampling policy by the acting player.
    # This computes counterfactual reach probabilities
    # Probability of the action actually taken (selected via the one-hot `timestep.action`) under each policy.
    sampling_policy = jnp.sum(timestep.policy * timestep.action, axis=-1, keepdims=True)
    network_policy = jnp.sum(pi * timestep.action, axis=-1, keepdims=True)
    # We do not take into account the player reaches, since infoset is always reached with the same prob
    sampling_policy = jnp.prod(sampling_policy, axis=-2, keepdims=True)
    # # TODO: what about invalid turns?
    importance_sampling = network_policy / sampling_policy
    # Shift by one step along axis 0 (prepend ones, drop the last step) so the cumulative product
    # at step t only contains ratios from steps before t.
    importance_sampling = jnp.concatenate((jnp.ones((1, *importance_sampling.shape[1:])), importance_sampling[:-1]), axis=0)
    importance_sampling = jnp.cumprod(importance_sampling, axis=0)
    # Reverse axis -2 so each entry along it receives the weight computed at the opposite index.
    importance_sampling = jnp.flip(importance_sampling, axis=-2)
    loss_neurd = neurd_loss(logit, pi, q_value, timestep.legal, importance_sampling)
    neurd_loss_value = -jnp.sum(loss_neurd * expanded_valid) / (normalization + (normalization == 0))
    return v_loss + neurd_loss_value
def non_abstracted_transformation_loss(
    self,
    transformation_params: Params,
    abstraction_params: Params,
    iset_encoder_params: Params,
    pi_before: chex.Array,
    pi_after: chex.Array,
    public_state: chex.Array,
    obs: chex.Array,
    legal: chex.Array,
    valid: chex.Array,
):
    """Transformation loss evaluated directly on the raw observations.

    Shares its signature with ``abstracted_transformation_loss``;
    ``abstraction_params``, ``iset_encoder_params`` and ``public_state`` are
    accepted but not used here.
    """
    loss = self.transformation_loss(
        transformation_params, pi_before, pi_after, obs, legal, valid
    )
    return loss
def abstracted_transformation_loss(
    self,
    transformation_params: Params,
    abstraction_params: Params,
    iset_encoder_params: Params,
    pi_before: chex.Array,
    pi_after: chex.Array,
    public_state: chex.Array,
    obs: chex.Array,
    legal: chex.Array,
    valid: chex.Array,
):
    """Transformation loss evaluated on abstracted infosets.

    ``_jit_get_abstraction`` is vmapped over the leading axis of
    ``public_state`` and ``obs`` (parameters are shared), and its output
    replaces ``obs`` in the call to ``transformation_loss``.
    """
    abstraction_fn = jax.vmap(
        self._jit_get_abstraction, in_axes=(None, None, 0, 0), out_axes=0
    )
    abstracted_obs = abstraction_fn(
        abstraction_params, iset_encoder_params, public_state, obs
    )
    return self.transformation_loss(
        transformation_params, pi_before, pi_after, abstracted_obs, legal, valid
    )
def transformation_loss(
    self,
    transformation_params: Params,
    pi_before: chex.Array,
    pi_after: chex.Array,
    obs: chex.Array,
    legal: chex.Array,
    valid: chex.Array,
):
    """Soft k-means loss between predicted and observed policy-update directions.

    The transformation network (vmapped over the leading axis of ``obs``)
    predicts directions; the observed direction is ``pi_after - pi_before``.
    Both are normalised with the mask ``legal * valid[..., None]``, reshaped
    by ``transform_trajectory_to_last_dimension`` and passed to
    ``compute_soft_kmeans_transformations``, whose first output is returned.
    """
    apply_over_leading = jax.vmap(
        self.transformation_network.apply, in_axes=(None, 0), out_axes=0
    )
    legal_and_valid = legal * valid[..., jnp.newaxis]

    # The predicted directions carry one extra axis before the action axis,
    # so the mask gets a matching singleton axis.
    predicted = normalize_direction_with_mask(
        apply_over_leading(transformation_params, obs),
        legal_and_valid[..., jnp.newaxis, :],
    )
    observed = normalize_direction_with_mask(pi_after - pi_before, legal_and_valid)

    # TODO: This makes the whole trajectory into a single policy vector. Shall we do it this way? Maybe compare it with the old implementation
    predicted = transform_trajectory_to_last_dimension(predicted)
    observed = transform_trajectory_to_last_dimension(observed)

    cluster_validity = jnp.ones(1)
    loss, _ = compute_soft_kmeans_transformations(
        observed,
        predicted,
        cluster_validity,
        self.config.transformation_soft_k_means_temperature,
        self.config.transformation_soft_k_means_closeness_assignment,
        self.config.transformation_soft_k_means_repulsive_force,
    )
    return loss
def non_abstracted_mvs_loss(
    self,
    mvs_params: Params,
    mvs_params_target: Params,
    policy_params: Params,
    transformation_params: tuple[Params, Params],
    abstraction_params: tuple[Params, Params],
    iset_encoder_params: tuple[Params, Params],
    timestep: TimeStep,
):
    """MVS loss evaluated on the raw per-player observations.

    Player 0 and player 1 observations are sliced from the second-to-last
    axis of ``timestep.obs``. Shares its signature with
    ``abstracted_mvs_loss``; ``abstraction_params`` and
    ``iset_encoder_params`` are accepted but not used here.
    """
    p1_obs = timestep.obs[..., 0, :]
    p2_obs = timestep.obs[..., 1, :]
    return self.mvs_loss(
        mvs_params,
        mvs_params_target,
        policy_params,
        transformation_params,
        p1_obs,
        p2_obs,
        timestep,
    )
def abstracted_mvs_loss(
    self,
    mvs_params: Params,
    mvs_params_target: Params,
    policy_params: Params,
    transformation_params: tuple[Params, Params],
    abstraction_params: tuple[Params, Params],
    iset_encoder_params: tuple[Params, Params],
    timestep: TimeStep,
):
    """MVS loss evaluated on abstracted per-player infosets.

    For each player, ``_jit_get_abstraction`` (vmapped over the leading axis
    of the public state and that player's observation slice) is called with
    that player's abstraction and infoset-encoder parameters; the results
    are passed to ``mvs_loss`` in place of the raw observations.
    """
    abstraction_fn = jax.vmap(
        self._jit_get_abstraction, in_axes=(None, None, 0, 0), out_axes=0
    )

    def _abstract_player(player):
        return abstraction_fn(
            abstraction_params[player],
            iset_encoder_params[player],
            timestep.public_state,
            timestep.obs[..., player, :],
        )

    p1_abstracted = _abstract_player(0)
    p2_abstracted = _abstract_player(1)
    return self.mvs_loss(
        mvs_params,
        mvs_params_target,
        policy_params,
        transformation_params,
        p1_abstracted,
        p2_abstracted,
        timestep,
    )
# TODO: This is only matrix-valued states now
def mvs_loss(self,
             mvs_params: Params,
             mvs_params_target: Params,
             rnad_params: Params,
             transformation_params: tuple[Params, Params],
             p1_obs: chex.Array,
             p2_obs: chex.Array,
             timestep: TimeStep):
    """Squared-error loss of the ``mvs`` network against a ``state_v_trace`` target.

    Args:
        mvs_params: Parameters of the ``mvs`` network being optimised.
        mvs_params_target: Parameters used to compute the values passed to
            ``state_v_trace``.
        rnad_params: Policy-network parameters used to compute ``pi`` from
            ``timestep.obs`` and ``timestep.legal``.
        transformation_params: Per-player parameters of the transformation
            network (index 0 for player 1, index 1 for player 2).
        p1_obs: Player 1 input to the ``mvs`` and transformation networks.
        p2_obs: Player 2 input to the ``mvs`` and transformation networks.
        timestep: Sampled data; ``obs``, ``legal``, ``valid``, ``policy``,
            ``action`` and ``reward`` are read here.

    Returns:
        The masked squared error summed over all entries and divided by
        ``sum(valid) * (config.transformations + 1) ** 2`` (or by 1 when that
        product is zero).
    """
    # Outer vmap maps axis -2 of obs/legal, inner vmap maps axis 0 of what remains.
    vectorized_policy = jax.vmap(jax.vmap(self.rnad_network.apply, in_axes=(None, 0, 0), out_axes=0), in_axes=(None, -2, -2), out_axes=-2)
    vectorized_transformation = jax.vmap(self.transformation_network.apply, in_axes=(None, 0), out_axes=0)
    vectorized_mvs = jax.vmap(self.mvs_network.apply, in_axes=(None, 0, 0), out_axes=0)
    pi, _, _, _ = vectorized_policy(rnad_params, timestep.obs, timestep.legal)
    mvs = vectorized_mvs(mvs_params, p1_obs, p2_obs)
    mvs_target = vectorized_mvs(mvs_params_target, p1_obs, p2_obs)
    p1_transformation_direction = vectorized_transformation(transformation_params[0], p1_obs)
    p2_transformation_direction = vectorized_transformation(transformation_params[1], p2_obs)
    # Dimension [Trajectory, Batch, Transformation, Player, Actions]
    transformation_direction = jnp.stack((p1_transformation_direction, p2_transformation_direction), axis=-2)
    transformation_direction = normalize_direction_with_mask(transformation_direction, jnp.expand_dims(timestep.legal * timestep.valid[..., jnp.newaxis, jnp.newaxis], -3))
    # Prepend an all-zero direction along the transformation axis, so index 0 corresponds to the untransformed policy.
    transformation_direction = jnp.concatenate((jnp.expand_dims(jnp.zeros_like(pi), -3), transformation_direction), -3)
    policy_transformations = jnp.expand_dims(pi, -3) + transformation_direction
    policy_transformations = jnp.maximum(policy_transformations, 1e-12) # To invalidate negative actions and zeros.
    # Invalid actions ?
    # Renormalise over the action axis so every transformed policy sums to 1.
    policy_transformations = policy_transformations / jnp.sum(policy_transformations, axis=-1, keepdims=True)
    mvs_train_target = state_v_trace(mvs_target, timestep.policy, policy_transformations, timestep.action, timestep.valid, timestep.reward, c=self.config.c_state_vtrace, rho=self.config.rho_state_vtrace)
    # mask = timestep.valid[..., jnp.newaxis, jnp.newaxis]
    # No gradient flows through the target; invalid steps contribute zero.
    loss_v = timestep.valid[..., jnp.newaxis, jnp.newaxis] * (mvs - lax.stop_gradient(mvs_train_target)) ** 2
    # NOTE(review): presumably `mvs` holds (transformations + 1) x (transformations + 1) values per step; confirm against the network.
    normalization = jnp.sum(timestep.valid) * ((self.config.transformations + 1) ** 2)
    # `(normalization == 0)` adds 1 to the denominator only when there are no valid steps.
    loss_v = jnp.sum(loss_v) / (normalization + (normalization == 0))
    return loss_v
def abstraction_loss(self,
abstraction_params: Params,
ps_decoder_params: Params,
iset_encoder_params: Params,
similarity_params: Params,
similarity_target: chex.Array,
public_state: chex.Array,
obs: chex.Array,
valid: chex.Array):
vectorized_abstraction = jax.vmap(self.abstraction_network.apply, in_axes=(None, 0), out_axes=0)
vectorized_ps_decoder = jax.vmap(jax.vmap(self.ps_decoder.apply, in_axes=(None, 0), out_axes=0), in_axes=(None, -2), out_axes=-2)
vectorized_iset_encoder = jax.vmap(self.iset_encoder.apply, in_axes=(None, 0), out_axes=0)