Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions checkpoint/orbax/checkpoint/experimental/caching/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

75 changes: 75 additions & 0 deletions checkpoint/orbax/checkpoint/experimental/caching/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Storage Service Client."""

import time
from absl import logging
import requests

SERVICE_URL = "http://service-dns/"


class StorageServiceClient:
"""Client for interacting with the Storage Tier Service."""

def __init__(self, service_url: str | None = None):
self._service_url = service_url or SERVICE_URL

def resolve(self, execution_id: int, step: int) -> str:
"""Resolves an asset path from the service."""
start = time.time()
logging.info("Resolving ID-step: %s-%s.", execution_id, step)
payload = {"execution_id": execution_id, "step": step}
response = requests.post(f"{self._service_url}/resolve", json=payload)
response.raise_for_status()
result = response.json()["path"]
end = time.time()
logging.info("Resolved %s in %s seconds.", result, end - start)
return result

def finalize(self, execution_id: int, step: int) -> None:
"""Finalizes an asset in the service."""
start = time.time()
payload = {"execution_id": execution_id, "step": step}
response = requests.post(f"{self._service_url}/finalize", json=payload)
response.raise_for_status()
end = time.time()
logging.info(
"Finalized %s %s in %s seconds.", execution_id, step, end - start
)

def prefetch(self, execution_id: int, step: int) -> None:
"""Prefetches an asset in the service."""
start = time.time()
payload = {"execution_id": execution_id, "step": step}
response = requests.post(f"{self._service_url}/prefetch", json=payload)
response.raise_for_status()
end = time.time()
logging.info(
"Prefetched %s %s in %s seconds.", execution_id, step, end - start
)

def await_transfer(self, execution_id: int, step: int) -> None:
"""Waits for any ongoing transfer for the asset to complete."""
start = time.time()
payload = {"execution_id": execution_id, "step": step}
response = requests.post(
f"{self._service_url}/await_transfer", json=payload
)
response.raise_for_status()
end = time.time()
logging.info(
"Awaited transfer %s %s in %s seconds.", execution_id, step, end - start
)
87 changes: 87 additions & 0 deletions checkpoint/orbax/checkpoint/experimental/caching/client_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for the storage service client."""

import unittest
from unittest import mock
from absl.testing import absltest
from orbax.checkpoint.experimental.caching import client
import requests


class StorageServiceClientTest(unittest.TestCase):

def setUp(self):
super().setUp()
self.mock_post = self.enter_context(
mock.patch.object(requests, "post", autospec=True)
)
self.client = client.StorageServiceClient(service_url="http://test-url")

def enter_context(self, context_manager):
result = context_manager.__enter__()
self.addCleanup(context_manager.__exit__, None, None, None)
return result

def test_resolve(self):
mock_response = mock.Mock()
mock_response.json.return_value = {"path": "/path/to/asset"}
mock_response.raise_for_status.return_value = None
self.mock_post.return_value = mock_response

result = self.client.resolve(123, 456)

self.assertEqual(result, "/path/to/asset")
self.mock_post.assert_called_once_with(
"http://test-url/resolve", json={"execution_id": 123, "step": 456}
)

def test_finalize(self):
mock_response = mock.Mock()
mock_response.raise_for_status.return_value = None
self.mock_post.return_value = mock_response

self.client.finalize(123, 456)

self.mock_post.assert_called_once_with(
"http://test-url/finalize", json={"execution_id": 123, "step": 456}
)

def test_prefetch(self):
mock_response = mock.Mock()
mock_response.raise_for_status.return_value = None
self.mock_post.return_value = mock_response

self.client.prefetch(123, 456)

self.mock_post.assert_called_once_with(
"http://test-url/prefetch", json={"execution_id": 123, "step": 456}
)

def test_await_transfer(self):
mock_response = mock.Mock()
mock_response.raise_for_status.return_value = None
self.mock_post.return_value = mock_response

self.client.await_transfer(123, 456)

self.mock_post.assert_called_once_with(
"http://test-url/await_transfer",
json={"execution_id": 123, "step": 456},
)


if __name__ == "__main__":
absltest.main()
138 changes: 138 additions & 0 deletions checkpoint/orbax/checkpoint/experimental/caching/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""FastAPI server for the Storage Service."""

from __future__ import annotations

import json

from absl import app
from absl import flags
from absl import logging
import fastapi
from orbax.checkpoint.experimental.caching import service

FLAGS = flags.FLAGS

uvicorn_app = fastapi.applications.FastAPI()
storage_service = service.StorageService(service.DEFAULT_CONFIG)


@uvicorn_app.get("/")
async def handle_hello():
try:
logging.info("Received HELLO request: on node %d", service.node_id())
return "Hello from Storage Service!"
except Exception as e: # pylint: disable=broad-exception-caught
logging.exception("Failed to handle HELLO: %s", e)
return json.dumps({"error": str(e)}), 400


@uvicorn_app.post("/resolve")
async def handle_resolve(data=fastapi.params.Body(...)):
try:
logging.info("Received RESOLVE request: on node %d", service.node_id())
execution_id = data.get("execution_id")
step = data.get("step")
asset_id = service.AssetId(execution_id=execution_id, step=step)
path = storage_service.resolve(asset_id)
return {"path": path}
except Exception as e: # pylint: disable=broad-exception-caught
logging.exception("Failed to resolve asset: %s", e)
return json.dumps({"error": str(e)}), 400


@uvicorn_app.post("/exists")
async def handle_exists(data=fastapi.params.Body(...)):
try:
logging.info("Received EXISTS request: on node %d", service.node_id())
execution_id = data.get("execution_id")
step = data.get("step")
asset_id = service.AssetId(execution_id=execution_id, step=step)
exists = storage_service.assets.exists(asset_id)
return {"exists": exists}
except Exception as e: # pylint: disable=broad-exception-caught
logging.exception("Failed to check asset existence: %s", e)
return json.dumps({"error": str(e)}), 400


@uvicorn_app.post("/finalize")
async def handle_finalize(data=fastapi.params.Body(...)):
try:
logging.info("Received FINALIZE request: on node %d", service.node_id())
execution_id = data.get("execution_id")
step = data.get("step")
asset_id = service.AssetId(execution_id=execution_id, step=step)
storage_service.finalize(asset_id)
return {"status": "ok"}
except Exception as e: # pylint: disable=broad-exception-caught
logging.exception("Failed to finalize asset: %s", e)
return json.dumps({"error": str(e)}), 400


@uvicorn_app.post("/await_transfer")
async def handle_await_transfer(data=fastapi.params.Body(...)):
"""Handles await_transfer request."""
try:
logging.info(
"Received AWAIT_TRANSFER request: on node %d", service.node_id()
)
execution_id = data.get("execution_id")
step = data.get("step")
asset_id = service.AssetId(execution_id=execution_id, step=step)
storage_service.await_transfer(asset_id)
return {"status": "ok"}
except Exception as e: # pylint: disable=broad-exception-caught
logging.exception("Failed to await transfer: %s", e)
return json.dumps({"error": str(e)}), 400


@uvicorn_app.post("/prefetch")
async def handle_prefetch(data=fastapi.params.Body(...)):
try:
logging.info("Received PREFETCH request: on node %d", service.node_id())
execution_id = data.get("execution_id")
step = data.get("step")
asset_id = service.AssetId(execution_id=execution_id, step=step)
storage_service.prefetch(asset_id)
return {"status": "ok"}
except Exception as e: # pylint: disable=broad-exception-caught
logging.exception("Failed to prefetch asset: %s", e)
return json.dumps({"error": str(e)}), 400


@uvicorn_app.get("/inspect")
async def handle_inspect():
try:
logging.info("Received INSPECT request: on node %d", service.node_id())
return {
"storages": storage_service.storages.to_json(),
"assets": storage_service.assets.to_json(),
}
except Exception as e: # pylint: disable=broad-exception-caught
logging.exception("Failed to inspect: %s", e)
return json.dumps({"error": str(e)}), 400


def main(_):
import uvicorn # pylint: disable=g-import-not-at-top # pytype: disable=import-error

FLAGS.alsologtostderr = True
logging.set_verbosity(logging.INFO)
uvicorn.run(uvicorn_app, host="0.0.0.0", port=8080)


if __name__ == "__main__":
app.run(main)
Loading
Loading