Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def _inference_process(
)

# Get the node ids on the current machine for the current node type
input_nodes = dataset.get_node_ids(node_type=args.inference_node_type)
input_nodes = dataset.fetch_node_ids(node_type=args.inference_node_type)
logger.info(
f"Rank {rank} got input nodes of shapes: {[f'{rank}: {node.shape}' for rank, node in input_nodes.items()]}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def _inference_process(
# We expect that each compute machine has the same input nodes.
# As such, we shard across the compute machine cluster.
# If this is not done, then all nodes will receive the same input nodes, which is not what we want.
input_nodes = dataset.get_node_ids(
input_nodes = dataset.fetch_node_ids(
rank=args.cluster_info.compute_node_rank,
world_size=args.cluster_info.num_compute_nodes,
)
Expand Down
14 changes: 7 additions & 7 deletions gigl/distributed/dist_ablp_neighborloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def __init__(
For Graph Store mode: `dict[int, ABLPInputNodes]`
Maps server_rank to an ABLPInputNodes dataclass containing anchor nodes,
positive labels, and negative labels with explicit node type and edge type info.
This is the return type of `RemoteDistDataset.get_ablp_input()`.
This is the return type of `RemoteDistDataset.fetch_ablp_input()`.
supervision_edge_type (Optional[Union[EdgeType, list[EdgeType]]]):
The edge type(s) to use for supervision.
For Colocated mode: Must be None iff the dataset is labeled homogeneous.
Expand Down Expand Up @@ -600,7 +600,7 @@ def _setup_for_graph_store(
Setup method for Graph Store mode.

Args:
input_nodes: ABLP input from RemoteDistDataset.get_ablp_input().
input_nodes: ABLP input from RemoteDistDataset.fetch_ablp_input().
Maps server_rank to ABLPInputNodes containing anchor nodes, positive/negative
labels with explicit node type and edge type information.
dataset: The RemoteDistDataset to sample from.
Expand All @@ -612,13 +612,13 @@ def _setup_for_graph_store(
Returns:
Tuple of (list[ABLPNodeSamplerInput], RemoteDistSamplingWorkerOptions, DatasetSchema).
"""
node_feature_info = dataset.get_node_feature_info()
edge_feature_info = dataset.get_edge_feature_info()
edge_types = dataset.get_edge_types()
node_feature_info = dataset.fetch_node_feature_info()
edge_feature_info = dataset.fetch_edge_feature_info()
edge_types = dataset.fetch_edge_types()
node_rank = dataset.cluster_info.compute_node_rank

# Get sampling ports for compute-storage connections.
sampling_ports = dataset.get_free_ports_on_storage_cluster(
sampling_ports = dataset.fetch_free_ports_on_storage_cluster(
num_ports=dataset.cluster_info.num_compute_nodes
)
sampling_port = sampling_ports[node_rank]
Expand Down Expand Up @@ -745,7 +745,7 @@ def _setup_for_graph_store(
edge_types=edge_types,
node_feature_info=node_feature_info,
edge_feature_info=edge_feature_info,
edge_dir=dataset.get_edge_dir(),
edge_dir=dataset.fetch_edge_dir(),
),
)

Expand Down
10 changes: 5 additions & 5 deletions gigl/distributed/distributed_neighborloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,13 @@ def _setup_for_graph_store(
f"When using Graph Store mode, input nodes must be of type (dict[int, torch.Tensor] | (NodeType, dict[int, torch.Tensor])), received {type(input_nodes)} ({type(input_nodes[0])}, {type(input_nodes[1])})"
)

node_feature_info = dataset.get_node_feature_info()
edge_feature_info = dataset.get_edge_feature_info()
edge_types = dataset.get_edge_types()
node_feature_info = dataset.fetch_node_feature_info()
edge_feature_info = dataset.fetch_edge_feature_info()
edge_types = dataset.fetch_edge_types()
node_rank = dataset.cluster_info.compute_node_rank

# Get sampling ports for compute-storage connections.
sampling_ports = dataset.get_free_ports_on_storage_cluster(
sampling_ports = dataset.fetch_free_ports_on_storage_cluster(
num_ports=dataset.cluster_info.num_compute_nodes
)
sampling_port = sampling_ports[node_rank]
Expand Down Expand Up @@ -379,7 +379,7 @@ def _setup_for_graph_store(
edge_types=edge_types,
node_feature_info=node_feature_info,
edge_feature_info=edge_feature_info,
edge_dir=dataset.get_edge_dir(),
edge_dir=dataset.fetch_edge_dir(),
),
)

Expand Down
62 changes: 31 additions & 31 deletions gigl/distributed/graph_store/remote_dist_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
local_rank (int): The local rank of the process on the compute node.
mp_sharing_dict (Optional[MutableMapping[str, torch.Tensor]]):
(Optional) If provided, will be used to share tensors across the local machine.
e.g. for `get_node_ids`.
e.g. for `fetch_node_ids`.
If provided, *must* be a `DictProxy` e.g. the return value of a mp.Manager.
ex. torch.multiprocessing.Manager().dict().
"""
Expand All @@ -65,10 +65,10 @@ def __init__(
def cluster_info(self) -> GraphStoreInfo:
return self._cluster_info

def get_node_feature_info(
def fetch_node_feature_info(
self,
) -> Union[FeatureInfo, dict[NodeType, FeatureInfo], None]:
"""Get node feature information from the registered dataset.
"""Fetch node feature information from the registered dataset.

Returns:
Node feature information, which can be:
Expand All @@ -81,10 +81,10 @@ def get_node_feature_info(
DistServer.get_node_feature_info,
)

def get_edge_feature_info(
def fetch_edge_feature_info(
self,
) -> Union[FeatureInfo, dict[EdgeType, FeatureInfo], None]:
"""Get edge feature information from the registered dataset.
"""Fetch edge feature information from the registered dataset.

Returns:
Edge feature information, which can be:
Expand All @@ -97,8 +97,8 @@ def get_edge_feature_info(
DistServer.get_edge_feature_info,
)

def get_edge_dir(self) -> Union[str, Literal["in", "out"]]:
"""Get the edge direction from the registered dataset.
def fetch_edge_dir(self) -> Union[str, Literal["in", "out"]]:
"""Fetch the edge direction from the registered dataset.

Returns:
The edge direction.
Expand All @@ -108,11 +108,11 @@ def get_edge_dir(self) -> Union[str, Literal["in", "out"]]:
DistServer.get_edge_dir,
)

def get_node_partition_book(
def fetch_node_partition_book(
self, node_type: Optional[NodeType] = None
) -> Optional[PartitionBook]:
"""
Gets the partition book for the specified node type.
Fetches the partition book for the specified node type.

Args:
node_type: The node type to look up. Must be ``None`` for
Expand All @@ -129,11 +129,11 @@ def get_node_partition_book(
node_type=node_type,
)

def get_edge_partition_book(
def fetch_edge_partition_book(
self, edge_type: Optional[EdgeType] = None
) -> Optional[PartitionBook]:
"""
Gets the partition book for the specified edge type.
Fetches the partition book for the specified edge type.

Args:
edge_type: The edge type to look up. Must be ``None`` for
Expand All @@ -157,7 +157,7 @@ def _infer_node_type_if_homogeneous_with_label_edges(
Auto-infers the default homogeneous node type for homogeneous datasets with label edges.
"""
if node_type is None:
node_types = self.get_node_types()
node_types = self.fetch_node_types()
if node_types is not None and DEFAULT_HOMOGENEOUS_NODE_TYPE in node_types:
node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
logger.info(
Expand All @@ -173,7 +173,7 @@ def _infer_edge_type_if_homogeneous_with_label_edges(
Auto-infers the default homogeneous edge type for homogeneous datasets with label edges.
"""
if edge_type is None:
edge_types = self.get_edge_types()
edge_types = self.fetch_edge_types()
if edge_types is not None and DEFAULT_HOMOGENEOUS_EDGE_TYPE in edge_types:
edge_type = DEFAULT_HOMOGENEOUS_EDGE_TYPE
logger.info(
Expand All @@ -182,7 +182,7 @@ def _infer_edge_type_if_homogeneous_with_label_edges(
)
return edge_type

def _get_node_ids(
def _fetch_node_ids(
self,
rank: Optional[int] = None,
world_size: Optional[int] = None,
Expand Down Expand Up @@ -211,7 +211,7 @@ def _get_node_ids(
node_ids = torch.futures.wait_all(futures)
return {server_rank: node_ids for server_rank, node_ids in enumerate(node_ids)}

def get_node_ids(
def fetch_node_ids(
self,
rank: Optional[int] = None,
world_size: Optional[int] = None,
Expand Down Expand Up @@ -247,31 +247,31 @@ def get_node_ids(

Get all nodes (no split filtering, no sharding):

>>> dataset.get_node_ids()
>>> dataset.fetch_node_ids()
{
0: tensor([0, 1, 2, 3, 4, 5, 6, 7]), # All 8 nodes from storage rank 0
1: tensor([8, 9, 10, 11, 12, 13, 14, 15]) # All 8 nodes from storage rank 1
}

Shard all nodes across 2 compute nodes (compute rank 0 gets first half from each storage):

>>> dataset.get_node_ids(rank=0, world_size=2)
>>> dataset.fetch_node_ids(rank=0, world_size=2)
{
0: tensor([0, 1, 2, 3]), # First 4 of all 8 nodes from storage rank 0
1: tensor([8, 9, 10, 11]) # First 4 of all 8 nodes from storage rank 1
}

Get only training nodes (no sharding):

>>> dataset.get_node_ids(split="train")
>>> dataset.fetch_node_ids(split="train")
{
0: tensor([0, 1, 2, 3]), # 4 training nodes from storage rank 0
1: tensor([8, 9, 10, 11]) # 4 training nodes from storage rank 1
}

Combine split and sharding (training nodes, sharded for compute rank 0):

>>> dataset.get_node_ids(rank=0, world_size=2, split="train")
>>> dataset.fetch_node_ids(rank=0, world_size=2, split="train")
{
0: tensor([0, 1]), # First 2 of 4 training nodes from storage rank 0
1: tensor([8, 9]) # First 2 of 4 training nodes from storage rank 1
Expand All @@ -297,7 +297,7 @@ def server_key(server_rank: int) -> str:
logger.info(
f"Compute rank {torch.distributed.get_rank()} is getting node ids from storage nodes"
)
node_ids = self._get_node_ids(rank, world_size, node_type, split)
node_ids = self._fetch_node_ids(rank, world_size, node_type, split)
for server_rank, node_id in node_ids.items():
node_id.share_memory_()
self._mp_sharing_dict[server_key(server_rank)] = node_id
Expand All @@ -311,9 +311,9 @@ def server_key(server_rank: int) -> str:
}
return node_ids
else:
return self._get_node_ids(rank, world_size, node_type, split)
return self._fetch_node_ids(rank, world_size, node_type, split)

def get_free_ports_on_storage_cluster(self, num_ports: int) -> list[int]:
def fetch_free_ports_on_storage_cluster(self, num_ports: int) -> list[int]:
"""
Get free ports from the storage master node.

Expand Down Expand Up @@ -351,7 +351,7 @@ def get_free_ports_on_storage_cluster(self, num_ports: int) -> list[int]:
logger.info(f"Compute rank {compute_cluster_rank} received free ports: {ports}")
return cast(list[int], ports)

def _get_ablp_input(
def _fetch_ablp_input(
self,
split: Literal["train", "val", "test"],
rank: Optional[int] = None,
Expand Down Expand Up @@ -389,7 +389,7 @@ def _get_ablp_input(
}

# TODO(#488) - support multiple supervision edge types
def get_ablp_input(
def fetch_ablp_input(
self,
split: Literal["train", "val", "test"],
rank: Optional[int] = None,
Expand Down Expand Up @@ -439,7 +439,7 @@ def get_ablp_input(

Get training ABLP input (heterogeneous):

>>> dataset.get_ablp_input(split="train", node_type=USER, supervision_edge_type=USER_TO_ITEM)
>>> dataset.fetch_ablp_input(split="train", node_type=USER, supervision_edge_type=USER_TO_ITEM)
{
0: ABLPInputNodes(
anchor_nodes=tensor([0, 1, 2]),
Expand Down Expand Up @@ -503,7 +503,7 @@ def wrap_ablp_input(
logger.info(
f"Compute rank {torch.distributed.get_rank()} is getting ABLP input from storage nodes"
)
raw_ablp_inputs = self._get_ablp_input(
raw_ablp_inputs = self._fetch_ablp_input(
split=split,
rank=rank,
world_size=world_size,
Expand Down Expand Up @@ -551,7 +551,7 @@ def wrap_ablp_input(
)
return returned_ablp_inputs
else:
raw_inputs = self._get_ablp_input(
raw_inputs = self._fetch_ablp_input(
split=split,
rank=rank,
world_size=world_size,
Expand All @@ -572,8 +572,8 @@ def wrap_ablp_input(
) in raw_inputs.items()
}

def get_edge_types(self) -> Optional[list[EdgeType]]:
"""Get the edge types from the registered dataset.
def fetch_edge_types(self) -> Optional[list[EdgeType]]:
"""Fetch the edge types from the registered dataset.

Returns:
The edge types in the dataset, None if the dataset is homogeneous.
Expand All @@ -583,8 +583,8 @@ def get_edge_types(self) -> Optional[list[EdgeType]]:
DistServer.get_edge_types,
)

def get_node_types(self) -> Optional[list[NodeType]]:
"""Get the node types from the registered dataset.
def fetch_node_types(self) -> Optional[list[NodeType]]:
"""Fetch the node types from the registered dataset.

Returns:
The node types in the dataset, None if the dataset is homogeneous.
Expand Down
Loading