diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index c690bc407d..f3d3bebef6 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -494,13 +494,14 @@ async def serve(self) -> None: self.service.set_latest_model_version(step_num) @classmethod - def get_actor(cls, config: Config): + def get_actor(cls, config: Config, runtime_env: Optional[dict] = None): """Get a Ray actor for the explorer.""" return ( ray.remote(cls) .options( name=config.explorer.name, namespace=ray.get_runtime_context().namespace, + runtime_env=runtime_env, ) .remote(config) ) diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index c42fccfb26..c594ac52d5 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -8,7 +8,7 @@ import time import traceback from abc import ABC, abstractmethod -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import pandas as pd import ray @@ -216,11 +216,15 @@ async def is_alive(self) -> bool: return True @classmethod - def get_actor(cls, config: Config): + def get_actor(cls, config: Config, runtime_env: Optional[dict] = None): """Get a Ray actor for the trainer.""" return ( ray.remote(cls) - .options(name=config.trainer.name, namespace=ray.get_runtime_context().namespace) + .options( + name=config.trainer.name, + namespace=ray.get_runtime_context().namespace, + runtime_env=runtime_env, + ) .remote(config) )