Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions model/orbax/experimental/model/voxel2obm/main_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
from collections.abc import Mapping
import os
import re
from typing import Any, Callable
from typing import Callable

from orbax.experimental.model import core as obm
from orbax.experimental.model.core.python import file_utils
from orbax.experimental.model.voxel2obm import voxel_asset_map_pb2

from .learning.brain.experimental import jax_data as jd


VOXEL_PROCESSOR_MIME_TYPE = 'application/protobuf; type=voxel.PlanProto'
VOXEL_PROCESSOR_VERSION = '0.0.1'
Expand All @@ -35,9 +37,7 @@


def voxel_plan_to_obm(
# TODO(b/447200841): use the true type hint after voxel module is
# implemented.
voxel_module: Any,
voxel_module: jd.AbstractVoxelModule,
input_signature: obm.Tree[obm.ShloTensorSpec],
output_signature: obm.Tree[obm.ShloTensorSpec],
subfolder: str = DEFAULT_VOXEL_MODULE_FOLDER,
Expand All @@ -53,11 +53,9 @@ def voxel_plan_to_obm(
Returns:
An `obm.SerializableFunction` representing the Voxel module.
"""
plan_proto = voxel_module.export_plan()
plan_proto_bytes = plan_proto.SerializeToString()

plan = voxel_module.export_plan()
unstructured_data = obm.manifest_pb2.UnstructuredData(
inlined_bytes=plan_proto_bytes,
inlined_bytes=plan.SerializeToString(),
mime_type=VOXEL_PROCESSOR_MIME_TYPE,
version=VOXEL_PROCESSOR_VERSION,
)
Expand Down Expand Up @@ -213,7 +211,7 @@ def _asset_map_to_obm_supplemental(


def voxel_global_supplemental_closure(
voxel_module: Any,
voxel_module: jd.AbstractVoxelModule,
) -> Callable[[str], Mapping[str, obm.GlobalSupplemental]] | None:
"""Returns a closure for saving Voxel assets and creating supplemental data.

Expand Down
15 changes: 7 additions & 8 deletions model/orbax/experimental/model/voxel2obm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,10 @@
"""Utilities for converting Voxel signatures to OBM."""

import pprint
import numpy as np
import jax
import numpy as np
from orbax.experimental.model import core as obm
from orbax.experimental.model.voxel2obm.voxel_mock import VoxelSpec


VoxelSignature = obm.Tree[VoxelSpec]
from .learning.brain.experimental import jax_data as jd


def _obm_to_voxel_dtype(t):
Expand All @@ -32,10 +29,12 @@ def _obm_to_voxel_dtype(t):

def obm_spec_to_voxel_signature(
spec: obm.Tree[obm.ShloTensorSpec],
) -> VoxelSignature:
) -> jd.VoxelSchemaTree:
try:
return jax.tree_util.tree_map(
lambda x: VoxelSpec(shape=x.shape, dtype=_obm_to_voxel_dtype(x.dtype)),
lambda x: jd.VoxelTensorSpec(
shape=x.shape, dtype=obm.shlo_dtype_to_np_dtype(x.dtype)
),
spec,
)
except Exception as err:
Expand All @@ -52,7 +51,7 @@ def _voxel_to_obm_dtype(t) -> obm.ShloDType:


def voxel_signature_to_obm_spec(
signature: VoxelSignature,
signature: jd.VoxelSchemaTree,
) -> obm.Tree[obm.ShloTensorSpec]:
try:
return jax.tree_util.tree_map(
Expand Down
32 changes: 13 additions & 19 deletions model/orbax/experimental/model/voxel2obm/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,43 +19,37 @@
import numpy as np
from orbax.experimental.model import core as obm
from orbax.experimental.model.voxel2obm import utils
from .learning.brain.experimental import jax_data as jd


class UtilsTest(parameterized.TestCase):

def test_voxel_spec_init(self):
spec = utils.VoxelSpec(shape=(1, 2), dtype=np.int32)
self.assertEqual(spec.shape, (1, 2))
self.assertEqual(spec.dtype, np.dtype('int32'))

spec = utils.VoxelSpec(shape=(1, 2), dtype=np.dtype('float32'))
self.assertEqual(spec.shape, (1, 2))
self.assertEqual(spec.dtype, np.dtype('float32'))

with self.assertRaisesRegex(
ValueError,
"""Invalid dtype: 'invalid' cannot be converted to np.dtype.""",
):
utils.VoxelSpec(shape=(1, 2), dtype='invalid')

def test_obm_spec_to_voxel_signature(self):
obm_spec = {
'a': obm.ShloTensorSpec(shape=(1, 2), dtype=obm.ShloDType.i32),
'b': obm.ShloTensorSpec(shape=(3,), dtype=obm.ShloDType.f32),
}
voxel_sig = utils.obm_spec_to_voxel_signature(obm_spec)
expected_voxel_sig = {
'a': utils.VoxelSpec(shape=(1, 2), dtype=np.int32),
'b': utils.VoxelSpec(shape=(3,), dtype=np.float32),
'a': jd.VoxelTensorSpec(
shape=(1, 2), dtype=np.dtype(np.int32)
),
'b': jd.VoxelTensorSpec(
shape=(3,), dtype=np.dtype(np.float32)
),
}

self.assertEqual(voxel_sig['a'], expected_voxel_sig['a'])
self.assertEqual(voxel_sig['b'], expected_voxel_sig['b'])

def test_voxel_signature_to_obm_spec(self):
voxel_sig = {
'a': utils.VoxelSpec(shape=(1, 2), dtype=np.int32),
'b': utils.VoxelSpec(shape=(3,), dtype=np.float32),
'a': jd.VoxelTensorSpec(
shape=(1, 2), dtype=np.dtype(np.int32)
),
'b': jd.VoxelTensorSpec(
shape=(3,), dtype=np.dtype(np.float32)
),
}
obm_spec = utils.voxel_signature_to_obm_spec(voxel_sig)
expected_obm_spec = {
Expand Down
Loading