Skip to content

Commit 8583282

Browse files
fix: restore payload_builder wrapping for dict schemas
1 parent e901417 commit 8583282

4 files changed

Lines changed: 60 additions & 0 deletions

File tree

src/taskgraph/transforms/task.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,14 @@
1616
from dataclasses import dataclass
1717
from typing import Callable, Literal, Optional, Union
1818

19+
import voluptuous
20+
1921
from taskgraph.transforms.base import TransformSequence
2022
from taskgraph.util.hash import hash_path
2123
from taskgraph.util.keyed_by import evaluate_keyed_by
2224
from taskgraph.util.schema import (
2325
IndexSchema,
26+
LegacySchema,
2427
OptimizationType,
2528
Schema,
2629
TaskPriority,
@@ -194,6 +197,14 @@ class PayloadBuilder:
194197

195198

196199
def payload_builder(name, schema):
200+
if isinstance(schema, dict):
201+
schema = LegacySchema(
202+
{
203+
voluptuous.Required("implementation"): name,
204+
voluptuous.Optional("os"): str,
205+
}
206+
).extend(schema)
207+
197208
def wrap(func):
198209
assert name not in payload_builders, f"duplicate payload builder name {name}"
199210
payload_builders[name] = PayloadBuilder(schema, func)

src/taskgraph/util/schema.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def validate_schema(schema, obj, msg_prefix):
4848
# Handle plain Python types (e.g. str, int) via msgspec.convert
4949
elif isinstance(schema, type):
5050
msgspec.convert(obj, schema)
51+
# Handle plain dict schemas (e.g. from downstream payload builders)
52+
elif isinstance(schema, dict):
53+
voluptuous.Schema(schema)(obj)
5154
else:
5255
raise TypeError(f"Unsupported schema type: {type(schema)}")
5356
except (

test/test_transforms_task.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pprint import pprint
88

99
import pytest
10+
import voluptuous
1011
from pytest_taskgraph import FakeParameters
1112

1213
from taskgraph.transforms import task
@@ -966,3 +967,36 @@ def test_task_priority(run_transform, graph_config, test_task):
966967
assert task_dict["task"]["priority"] == priority
967968
else:
968969
assert task_dict["task"]["priority"] == graph_config["task-priority"]
970+
971+
972+
@pytest.fixture
973+
def dict_schema_builder():
974+
@task.payload_builder("test-builder", schema={"command": [str]})
975+
def _builder(config, task, task_def):
976+
pass
977+
978+
yield task.payload_builders["test-builder"].schema
979+
task.payload_builders.pop("test-builder", None)
980+
981+
982+
@pytest.mark.parametrize(
983+
"payload",
984+
(
985+
{"implementation": "test-builder", "command": ["echo"]},
986+
{"implementation": "test-builder", "command": ["echo"], "os": "linux"},
987+
),
988+
)
989+
def test_dict_schema_accepts_valid_payload(dict_schema_builder, payload):
990+
dict_schema_builder(payload)
991+
992+
993+
@pytest.mark.parametrize(
994+
"payload",
995+
(
996+
{"implementation": "wrong-name", "command": ["echo"]},
997+
{"command": ["echo"]},
998+
),
999+
)
1000+
def test_dict_schema_rejects_invalid_payload(dict_schema_builder, payload):
1001+
with pytest.raises(voluptuous.MultipleInvalid):
1002+
dict_schema_builder(payload)

test/test_util_schema.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,18 @@ def test_index_schema_accepts_all_fields(self):
271271
)
272272

273273

274+
class TestValidateSchemaDictHandler(unittest.TestCase):
275+
"""validate_schema must accept plain dict schemas passed
276+
by downstream payload builders without raising TypeError."""
277+
278+
def test_dict_schema_valid(self):
279+
validate_schema({"name": str, "count": int}, {"name": "a", "count": 1}, "pfx")
280+
281+
def test_dict_schema_invalid(self):
282+
with self.assertRaises(Exception):
283+
validate_schema({"name": str}, {"name": 123}, "pfx")
284+
285+
274286
def test_optionally_keyed_by():
275287
typ = optionally_keyed_by("foo", str, use_msgspec=True)
276288
assert msgspec.convert("baz", typ) == "baz"

0 commit comments

Comments
 (0)