-
Notifications
You must be signed in to change notification settings - Fork 693
Adapt to dlsime v0.0.2 #4242
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adapt to dlsime v0.0.2 #4242
Changes from all commits
e5d2b35
905d22f
2603e03
7776acd
6cdb3bd
7f2f652
b728693
de3c3f4
2cb62af
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,9 +2,8 @@ | |
| import asyncio | ||
| import json | ||
| import os | ||
| from typing import Dict, List | ||
| from typing import Dict | ||
|
|
||
| from dlslime import Assignment as DLSlimeAssignment | ||
| from dlslime import NVLinkEndpoint, RDMAEndpoint, available_nic | ||
|
|
||
| from lmdeploy.logger import get_logger | ||
|
|
@@ -20,97 +19,74 @@ | |
| LMDEPLOY_USE_ASYNC_MIGRATION = os.environ.get('LMDEPLOY_USE_ASYNC_MIGRATION', None) | ||
|
|
||
|
|
||
| async def read_batch_coroutine(endpoint: RDMAEndpoint, batch: List[DLSlimeAssignment]): | ||
| loop = asyncio.get_running_loop() | ||
| future = loop.create_future() | ||
|
|
||
| def _completion_handler(status: int): | ||
| loop.call_soon_threadsafe(future.set_result, status) | ||
|
|
||
| endpoint.read_batch_with_callback( | ||
| batch, | ||
| _completion_handler, | ||
| ) | ||
| await future | ||
|
|
||
|
|
||
| class DLSlimeMigrationManagement: | ||
|
|
||
| def __init__(self, init_request: DistServeInitRequest): | ||
| self.rank = init_request.rank | ||
| self.local_engine_config: DistServeEngineConfig = init_request.local_engine_config | ||
| self.remote_engine_config: DistServeEngineConfig = init_request.remote_engine_config | ||
| self.endpoint: Dict[MigrationProtocol, RDMAEndpoint] = { | ||
| MigrationProtocol.TCP: None, | ||
| MigrationProtocol.RDMA: None, | ||
| MigrationProtocol.NVLINK: None, | ||
| } | ||
| self.local_engine_config: DistServeEngineConfig = (init_request.local_engine_config) | ||
| self.remote_engine_config: DistServeEngineConfig = (init_request.remote_engine_config) | ||
| self.endpoint: Dict[MigrationProtocol, RDMAEndpoint | NVLinkEndpoint] = {} | ||
| if init_request.protocol == MigrationProtocol.RDMA: | ||
| nics = available_nic() | ||
| device_name = nics[self.rank % len(nics)] | ||
| logger.info(f'use device {device_name} for kv migration') | ||
| self.endpoint[MigrationProtocol.RDMA] = RDMAEndpoint(device_name=device_name, | ||
| ib_port=1, | ||
| link_type=init_request.rdma_config.link_type.name) | ||
| self.endpoint[MigrationProtocol.RDMA] = RDMAEndpoint( | ||
| device_name=device_name, | ||
| ib_port=1, | ||
| link_type=init_request.rdma_config.link_type.name, | ||
| ) | ||
| elif init_request.protocol == MigrationProtocol.NVLINK: | ||
| self.endpoint[MigrationProtocol.NVLINK] = NVLinkEndpoint() | ||
|
|
||
| def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage): | ||
| self.endpoint[register_mr_request.protocol].register_memory_region(register_mr_request.mr_key, | ||
| register_mr_request.addr, | ||
| register_mr_request.offset, | ||
| register_mr_request.length) | ||
| self.endpoint[register_mr_request.protocol].register_memory_region( | ||
| register_mr_request.mr_key, | ||
| register_mr_request.addr, | ||
| register_mr_request.offset, | ||
| register_mr_request.length, | ||
| ) | ||
|
|
||
| def connect(self, kvtransfer_endpoint_info: DistServeKVTransferEndpointInfo): | ||
| self.endpoint[kvtransfer_endpoint_info.protocol].connect(json.loads(kvtransfer_endpoint_info.endpoint_info)) | ||
|
|
||
| async def p2p_migrate(self, assignment: MigrationAssignment, async_op=False): | ||
| batch = [ | ||
| DLSlimeAssignment( | ||
| mr_key=assign.mr_key, | ||
| target_offset=assign.target_offset, | ||
| source_offset=assign.source_offset, | ||
| length=assign.length, | ||
| ) for assign in assignment.batch | ||
| ] | ||
|
|
||
| if not LMDEPLOY_USE_ASYNC_MIGRATION: | ||
| MAX_NUM_READ_BATCH = 4096 | ||
|
|
||
| def split(batch: List[DLSlimeAssignment]): | ||
| batch_split = [] | ||
| for i in range(0, len(batch), MAX_NUM_READ_BATCH): | ||
| batch_split.append(batch[i:i + MAX_NUM_READ_BATCH]) | ||
| return batch_split | ||
|
|
||
| batch_splited = split(batch) | ||
| for b_split in batch_splited: | ||
| self.endpoint[assignment.protocol].read_batch(b_split) | ||
| async def p2p_migrate(self, assignment: MigrationAssignment): | ||
| batch = [( | ||
| assign.mr_key, | ||
| assign.mr_key, | ||
|
Comment on lines
+54
to
+55
|
||
| assign.target_offset, | ||
| assign.source_offset, | ||
| assign.length, | ||
| ) for assign in assignment.batch] | ||
|
|
||
| future = self.endpoint[assignment.protocol].read(batch) | ||
| if LMDEPLOY_USE_ASYNC_MIGRATION: | ||
| loop = asyncio.get_running_loop() | ||
| return await loop.run_in_executor(None, future.wait) | ||
| else: | ||
| await read_batch_coroutine(self.endpoint[assignment.protocol], batch) | ||
| return future.wait() | ||
|
|
||
|
|
||
| @MIGRATION_BACKENDS.register_module(MigrationBackend.DLSlime.name) | ||
| class DLSlimeBackend(MigrationBackendImpl): | ||
| """DLSlime Transfer Engine.""" | ||
|
|
||
| def __init__(self): | ||
| self.links: Dict[int, DLSlimeMigrationManagement] = {} | ||
| self.links: Dict[str, DLSlimeMigrationManagement] = {} | ||
|
|
||
| def p2p_initialize(self, init_request: DistServeInitRequest): | ||
| self.links[init_request.remote_engine_id] = DLSlimeMigrationManagement(init_request) | ||
|
|
||
| def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage): | ||
| self.links[register_mr_request.remote_engine_id].register_memory_region(register_mr_request) | ||
|
|
||
| def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol): | ||
| return self.links[remote_engine_id].endpoint[protocol].endpoint_info | ||
| def endpoint_info(self, remote_engine_id: str, protocol: MigrationProtocol): | ||
|
||
| return self.links[remote_engine_id].endpoint[protocol].endpoint_info() | ||
|
|
||
| def p2p_connect(self, remote_engine_id: str, conn_req: DistServeKVTransferEndpointInfo): | ||
| self.links[remote_engine_id].connect(conn_req) | ||
|
|
||
| async def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False): | ||
| await self.links[assignment.remote_engine_id].p2p_migrate(assignment, async_op=async_op) | ||
| await self.links[assignment.remote_engine_id].p2p_migrate(assignment) | ||
|
|
||
| def store(self, assignment: MigrationAssignment, async_op: bool = False): | ||
| raise NotImplementedError | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The parentheses around the assignment values are unnecessary and don't serve any purpose. They can be removed for cleaner code.