From 8cf03a3071f6c2a01386823ece8b5274a053247f Mon Sep 17 00:00:00 2001 From: Neil Vaytet Date: Mon, 18 May 2026 17:35:41 +0200 Subject: [PATCH 1/6] LookupTable per RunType --- .../essreduce/src/ess/reduce/unwrap/lut.py | 82 +++++++++++++------ .../src/ess/reduce/unwrap/to_wavelength.py | 31 ++++--- .../essreduce/src/ess/reduce/unwrap/types.py | 12 ++- .../src/ess/reduce/unwrap/workflow.py | 6 +- 4 files changed, 86 insertions(+), 45 deletions(-) diff --git a/packages/essreduce/src/ess/reduce/unwrap/lut.py b/packages/essreduce/src/ess/reduce/unwrap/lut.py index c3a4a14aa..10007319c 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/lut.py +++ b/packages/essreduce/src/ess/reduce/unwrap/lut.py @@ -5,15 +5,18 @@ """ import warnings +from collections.abc import Callable from dataclasses import dataclass from typing import NewType import numpy as np import sciline as sl import scipp as sc +import scippnexus as snx +from scippneutron.chopper import DiskChopper from scippneutron.tof import chopper_cascade -from ..nexus.types import AnyRun, DiskChoppers +from ..nexus.types import DiskChoppers, Position, RunType from .types import LookupTable @@ -52,7 +55,7 @@ def __post_init__(self): @dataclass -class SimulationResults: +class SimulationResultsBaseClass: """ Results of a time-of-flight simulation used to create a lookup table. It should contain readings at various positions along the beamline, e.g., at @@ -73,7 +76,31 @@ class SimulationResults: """ readings: dict[str, BeamlineComponentReading] - choppers: DiskChoppers[AnyRun] | None = None + choppers: dict[str, DiskChopper] | None = None + + +class SimulationResults( + sl.Scope[RunType, SimulationResultsBaseClass], + SimulationResultsBaseClass, +): + """ + Results of a time-of-flight simulation used to create a lookup table. + It should contain readings at various positions along the beamline, e.g., at + the source and after each chopper. + It also contains the chopper parameters used in the simulation, so it can be + determined if this simulation is compatible with a given experiment. + + Parameters + ---------- + readings: + A dict of :class:`BeamlineComponentReading` objects representing the readings at + various positions along the beamline. The keys in the dict should correspond to + the names of the components (e.g., 'source', 'chopper1', etc.). + choppers: + The chopper parameters used in the simulation (if any). These are used to verify + that the simulation is compatible with a given experiment (comparing chopper + openings, frequencies, phases, etc.). + """ NumberOfSimulatedNeutrons = NewType("NumberOfSimulatedNeutrons", int) @@ -151,11 +178,13 @@ class SourceBounds: """Wavelength range (min, max) of the neutrons in the source pulse.""" -ChopperFrameSequence = NewType("ChopperFrameSequence", chopper_cascade.FrameSequence) -""" -Sequence of chopper frames used to compute the wavelength as a function of distance and -event_time_offset in the lookup table. -""" +class ChopperFrameSequence( + sl.Scope[RunType, chopper_cascade.FrameSequence], chopper_cascade.FrameSequence +): + """ + Sequence of chopper frames used to compute the wavelength as a function of distance + and event_time_offset in the lookup table. + """ def _compute_mean_wavelength( @@ -224,13 +253,13 @@ def _compute_mean_wavelength( def make_wavelength_lookup_table( - simulation: SimulationResults, + simulation: SimulationResults[RunType], ltotal_range: LtotalRange, distance_resolution: DistanceResolution, time_resolution: TimeResolution, pulse_period: PulsePeriod, pulse_stride: PulseStride, -) -> LookupTable: +) -> LookupTable[RunType]: """ Compute a lookup table for wavelength as a function of distance and time-of-arrival. @@ -378,7 +407,7 @@ def make_wavelength_lookup_table( }, ) - return LookupTable( + return LookupTable[RunType]( array=table, pulse_period=pulse_period, pulse_stride=pulse_stride, @@ -393,7 +422,7 @@ def make_wavelength_lookup_table( ) -def _to_component_reading(component): +def _to_component_reading(component) -> BeamlineComponentReading: events = component.data.squeeze().flatten(to='event') sel = sc.full(value=True, sizes=events.sizes) for key in {'blocked_by_others', 'blocked_by_me'} & set(events.masks.keys()): @@ -412,13 +441,13 @@ def _to_component_reading(component): def simulate_chopper_cascade_using_tof( - choppers: DiskChoppers[AnyRun], + choppers: DiskChoppers[RunType], source_position: SourcePosition, neutrons: NumberOfSimulatedNeutrons, pulse_stride: PulseStride, seed: SimulationSeed, facility: SimulationFacility, -) -> SimulationResults: +) -> SimulationResults[RunType]: """ Simulate a pulse of neutrons propagating through a chopper cascade using the ``tof`` package (https://scipp.github.io/tof). @@ -457,12 +486,12 @@ def simulate_chopper_cascade_using_tof( ) sim_readings = {"source": _to_component_reading(source)} if not tof_choppers: - return SimulationResults(readings=sim_readings, choppers=None) + return SimulationResults[RunType](readings=sim_readings, choppers=None) model = tof.Model(source=source, choppers=tof_choppers) results = model.run() for name, ch in results.choppers.items(): sim_readings[name] = _to_component_reading(ch) - return SimulationResults(readings=sim_readings, choppers=choppers) + return SimulationResults[RunType](readings=sim_readings, choppers=choppers) def LookupTableWorkflow(): @@ -617,11 +646,11 @@ def _estimate_wavelength_by_polygon_centers( def compute_frame_sequence( pulse_period: PulsePeriod, - disk_choppers: DiskChoppers[AnyRun], - source_position: SourcePosition, + disk_choppers: DiskChoppers[RunType], + source_position: Position[snx.NXsource, RunType], source_bounds: SourceBounds, pulse_stride: PulseStride, -) -> ChopperFrameSequence: +) -> ChopperFrameSequence[RunType]: """ Compute the chopper frame sequence for a given set of disk choppers and source pulse parameters. @@ -680,7 +709,7 @@ def compute_frame_sequence( npulses=pulse_stride, ) frames = frames.chop(chops.values()) - return ChopperFrameSequence(frames) + return ChopperFrameSequence[RunType](frames) def make_wavelength_lut_from_polygons( @@ -690,7 +719,7 @@ def make_wavelength_lut_from_polygons( pulse_period: PulsePeriod, pulse_stride: PulseStride, frames: ChopperFrameSequence, -) -> LookupTable: +) -> LookupTable[RunType]: """ Compute a lookup table for wavelength as a function of distance and time-of-arrival. @@ -776,7 +805,7 @@ def make_wavelength_lut_from_polygons( coords={"distance": distances, "event_time_offset": time_edges}, ) - return LookupTable( + return LookupTable[RunType]( array=table, pulse_period=pulse_period, pulse_stride=pulse_stride, @@ -787,13 +816,20 @@ def make_wavelength_lut_from_polygons( ) +def providers() -> tuple[Callable]: + """ + Return the providers for creating the wavelength lookup table. + """ + return (make_wavelength_lut_from_polygons, compute_frame_sequence) + + def FastLookupTableWorkflow(): """ Create a workflow for computing a wavelength lookup table from computing an acceptance diagram for a pulse propagating through a chopper cascade. """ wf = sl.Pipeline( - (make_wavelength_lut_from_polygons, compute_frame_sequence), + providers(), params={ PulsePeriod: 1.0 / sc.scalar(14.0, unit="Hz"), PulseStride: 1, diff --git a/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py b/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py index e3a40e777..7d4860bbc 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py +++ b/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py @@ -40,6 +40,7 @@ ErrorLimitedLookupTable, LookupTable, LookupTableRelativeErrorThreshold, + Lut, MonitorLtotal, PulseStrideOffset, WavelengthDetector, @@ -137,7 +138,7 @@ def __call__( def _compute_wavelength_histogram( - da: sc.DataArray, lookup: ErrorLimitedLookupTable, ltotal: sc.Variable + da: sc.DataArray, lookup: Lut, ltotal: sc.Variable ) -> sc.DataArray: # In NeXus, 'time_of_flight' is the canonical name in NXmonitor, but in some files, # it may be called 'tof' or 'frame_time'. @@ -243,7 +244,7 @@ def _guess_pulse_stride_offset( def _prepare_wavelength_interpolation_inputs( da: sc.DataArray, - lookup: ErrorLimitedLookupTable, + lookup: Lut, ltotal: sc.Variable, pulse_stride_offset: int | None, ) -> dict: @@ -336,7 +337,7 @@ def _prepare_wavelength_interpolation_inputs( def _compute_wavelength_events( da: sc.DataArray, - lookup: ErrorLimitedLookupTable, + lookup: Lut, ltotal: sc.Variable, pulse_stride_offset: int | None, ) -> sc.DataArray: @@ -435,9 +436,7 @@ def monitor_ltotal_from_straight_line_approximation( ) -def _mask_large_uncertainty_in_lut( - table: LookupTable, error_threshold: float -) -> LookupTable: +def _mask_large_uncertainty_in_lut(table: Lut, error_threshold: float) -> Lut: """ Mask regions in the lookup table with large uncertainty using NaNs. @@ -452,7 +451,7 @@ def _mask_large_uncertainty_in_lut( da = table.array relative_error = sc.stddevs(da.data) / sc.values(da.data) mask = relative_error > sc.scalar(error_threshold) - return LookupTable( + return LookupTable[RunType]( **{ **asdict(table), "array": sc.where(mask, sc.scalar(np.nan, unit=da.unit), da), @@ -461,10 +460,10 @@ def _mask_large_uncertainty_in_lut( def mask_large_uncertainty_in_lut_detector( - table: LookupTable, + table: LookupTable[RunType], error_threshold: LookupTableRelativeErrorThreshold, detector_name: NeXusDetectorName, -) -> ErrorLimitedLookupTable[snx.NXdetector]: +) -> ErrorLimitedLookupTable[RunType, snx.NXdetector]: """ Mask regions in the wavelength lookup table with large uncertainty using NaNs. @@ -479,7 +478,7 @@ def mask_large_uncertainty_in_lut_detector( Name of the detector for which to apply the error threshold. This is used to get the correct error threshold from the dictionary of error thresholds. """ - return ErrorLimitedLookupTable[snx.NXdetector]( + return ErrorLimitedLookupTable[RunType, snx.NXdetector]( _mask_large_uncertainty_in_lut( table=table, error_threshold=error_threshold[detector_name] ) @@ -487,10 +486,10 @@ def mask_large_uncertainty_in_lut_detector( def mask_large_uncertainty_in_lut_monitor( - table: LookupTable, + table: LookupTable[RunType], error_threshold: LookupTableRelativeErrorThreshold, monitor_name: NeXusName[MonitorType], -) -> ErrorLimitedLookupTable[MonitorType]: +) -> ErrorLimitedLookupTable[RunType, MonitorType]: """ Mask regions in the wavelength lookup table with large uncertainty using NaNs. @@ -505,7 +504,7 @@ def mask_large_uncertainty_in_lut_monitor( Name of the monitor for which to apply the error threshold. This is used to get the correct error threshold from the dictionary of error thresholds. """ - return ErrorLimitedLookupTable[MonitorType]( + return ErrorLimitedLookupTable[RunType, MonitorType]( _mask_large_uncertainty_in_lut( table=table, error_threshold=error_threshold[monitor_name] ) @@ -514,7 +513,7 @@ def mask_large_uncertainty_in_lut_monitor( def _compute_wavelength_data( da: sc.DataArray, - lookup: ErrorLimitedLookupTable[Component], + lookup: ErrorLimitedLookupTable[RunType, Component], ltotal: sc.Variable, pulse_stride_offset: int, ) -> sc.DataArray: @@ -533,7 +532,7 @@ def _compute_wavelength_data( def detector_wavelength_data( detector_data: RawDetector[RunType], - lookup: ErrorLimitedLookupTable[snx.NXdetector], + lookup: ErrorLimitedLookupTable[RunType, snx.NXdetector], ltotal: DetectorLtotal[RunType], pulse_stride_offset: PulseStrideOffset, ) -> WavelengthDetector[RunType]: @@ -568,7 +567,7 @@ def detector_wavelength_data( def monitor_wavelength_data( monitor_data: RawMonitor[RunType, MonitorType], - lookup: ErrorLimitedLookupTable[MonitorType], + lookup: ErrorLimitedLookupTable[RunType, MonitorType], ltotal: MonitorLtotal[RunType, MonitorType], pulse_stride_offset: PulseStrideOffset, ) -> WavelengthMonitor[RunType, MonitorType]: diff --git a/packages/essreduce/src/ess/reduce/unwrap/types.py b/packages/essreduce/src/ess/reduce/unwrap/types.py index 86dd086c9..babcccd5f 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/types.py +++ b/packages/essreduce/src/ess/reduce/unwrap/types.py @@ -15,9 +15,10 @@ @dataclass -class LookupTable: +class Lut: """ - Lookup table giving wavelength as a function of distance and ``event_time_offset``. + Base class for a lookup table giving wavelength as a function of distance and + ``event_time_offset``. """ array: sc.DataArray @@ -44,7 +45,12 @@ def plot(self, *args, **kwargs) -> Any: return self.array.plot(*args, **kwargs) -class ErrorLimitedLookupTable(sl.Scope[Component, LookupTable], LookupTable): +class LookupTable(sl.Scope[RunType, Lut], Lut): + """Lookup table giving wavelength as a function of distance and + ``event_time_offset``.""" + + +class ErrorLimitedLookupTable(sl.Scope[RunType, Component, Lut], Lut): """Lookup table that is masked with NaNs in regions where the standard deviation of the wavelength is above a certain threshold.""" diff --git a/packages/essreduce/src/ess/reduce/unwrap/workflow.py b/packages/essreduce/src/ess/reduce/unwrap/workflow.py index b493162f5..c43585cf1 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/workflow.py +++ b/packages/essreduce/src/ess/reduce/unwrap/workflow.py @@ -6,7 +6,7 @@ import scipp as sc from ..nexus import GenericNeXusWorkflow -from . import to_wavelength +from . import lut, to_wavelength from .types import LookupTable, LookupTableFilename, PulseStrideOffset @@ -82,10 +82,10 @@ def GenericUnwrapWorkflow( """ wf = GenericNeXusWorkflow(run_types=run_types, monitor_types=monitor_types) - for provider in to_wavelength.providers(): + for provider in (*to_wavelength.providers(), *lut.providers()): wf.insert(provider) - wf.insert(load_lookup_table) + # wf.insert(load_lookup_table) # Default parameters wf[PulseStrideOffset] = None From af767197538f4df9a3351538a5cd5e54a3ca27d3 Mon Sep 17 00:00:00 2001 From: Neil Vaytet Date: Mon, 18 May 2026 18:45:05 +0200 Subject: [PATCH 2/6] add provider for DiskChoppers --- .../src/ess/reduce/nexus/workflow.py | 18 +++++-- .../essreduce/src/ess/reduce/unwrap/lut.py | 52 ++++++++++--------- .../src/ess/reduce/unwrap/to_wavelength.py | 10 ++-- .../src/ess/reduce/unwrap/workflow.py | 2 + 4 files changed, 51 insertions(+), 31 deletions(-) diff --git a/packages/essreduce/src/ess/reduce/nexus/workflow.py b/packages/essreduce/src/ess/reduce/nexus/workflow.py index dcd0a7988..6e320d53f 100644 --- a/packages/essreduce/src/ess/reduce/nexus/workflow.py +++ b/packages/essreduce/src/ess/reduce/nexus/workflow.py @@ -15,7 +15,7 @@ import scippnexus as snx from scipp.constants import g from scipp.core import label_based_index_to_positional_index -from scippneutron.chopper import extract_chopper_from_nexus +from scippneutron.chopper import DiskChopper, extract_chopper_from_nexus from scippneutron.metadata import RadiationProbe, SourceType from . import _nexus_loader as nexus @@ -26,6 +26,7 @@ Component, DetectorBankSizes, DetectorPositionOffset, + DiskChoppers, DynamicPosition, EmptyDetector, EmptyMonitor, @@ -537,7 +538,7 @@ def assemble_monitor_data( return RawMonitor[RunType, MonitorType](_add_variances(da)) -def parse_disk_choppers( +def parse_raw_choppers( choppers: AllNeXusComponents[snx.NXdisk_chopper, RunType], ) -> RawChoppers[RunType]: """Convert the NeXus representation of a chopper to ours. @@ -559,6 +560,17 @@ def parse_disk_choppers( ) +def to_disk_choppers(choppers: RawChoppers[RunType]) -> DiskChoppers[RunType]: + disk_choppers = { + # If there is no beam_position, we set it to 0 by default. + key: DiskChopper.from_nexus( + {**{'beam_position': sc.scalar(0.0, unit='deg')}, **ch} + ) + for key, ch in choppers.items() + } + return DiskChoppers[RunType](disk_choppers) + + def load_proton_charge( parent_location: NeXusComponentLocationSpec[ProductionInfo, RunType], interval: TimeInterval[RunType], @@ -789,7 +801,7 @@ def load_source_metadata_from_nexus( assemble_detector_data, ) -_chopper_providers = (parse_disk_choppers,) +_chopper_providers = (parse_raw_choppers, to_disk_choppers) _metadata_providers = ( load_beamline_metadata_from_nexus, diff --git a/packages/essreduce/src/ess/reduce/unwrap/lut.py b/packages/essreduce/src/ess/reduce/unwrap/lut.py index 10007319c..335769342 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/lut.py +++ b/packages/essreduce/src/ess/reduce/unwrap/lut.py @@ -17,7 +17,7 @@ from scippneutron.tof import chopper_cascade from ..nexus.types import DiskChoppers, Position, RunType -from .types import LookupTable +from .types import LookupTable, Lut @dataclass @@ -806,13 +806,16 @@ def make_wavelength_lut_from_polygons( ) return LookupTable[RunType]( - array=table, - pulse_period=pulse_period, - pulse_stride=pulse_stride, - distance_resolution=table.coords["distance"][1] - table.coords["distance"][0], - time_resolution=table.coords["event_time_offset"][1] - - table.coords["event_time_offset"][0], - # TODO: Do we still want to store the chopper information in the lookup table? + Lut( + array=table, + pulse_period=pulse_period, + pulse_stride=pulse_stride, + distance_resolution=table.coords["distance"][1] + - table.coords["distance"][0], + time_resolution=table.coords["event_time_offset"][1] + - table.coords["event_time_offset"][0], + # TODO: Do we still want to store the chopper information in the lookup table? + ) ) @@ -823,25 +826,26 @@ def providers() -> tuple[Callable]: return (make_wavelength_lut_from_polygons, compute_frame_sequence) +def default_parameters() -> dict: + return { + PulsePeriod: 1.0 / sc.scalar(14.0, unit="Hz"), + PulseStride: 1, + DistanceResolution: sc.scalar(0.1, unit="m"), + TimeResolution: sc.scalar(250.0, unit='us'), + SourceBounds: SourceBounds( + time=(sc.scalar(0.0, unit='ms'), sc.scalar(5.0, unit='ms')), + wavelength=( + sc.scalar(0.0, unit='angstrom'), + sc.scalar(15.0, unit='angstrom'), + ), + ), + } + + def FastLookupTableWorkflow(): """ Create a workflow for computing a wavelength lookup table from computing an acceptance diagram for a pulse propagating through a chopper cascade. """ - wf = sl.Pipeline( - providers(), - params={ - PulsePeriod: 1.0 / sc.scalar(14.0, unit="Hz"), - PulseStride: 1, - DistanceResolution: sc.scalar(0.1, unit="m"), - TimeResolution: sc.scalar(250.0, unit='us'), - SourceBounds: SourceBounds( - time=(sc.scalar(0.0, unit='ms'), sc.scalar(5.0, unit='ms')), - wavelength=( - sc.scalar(0.0, unit='angstrom'), - sc.scalar(15.0, unit='angstrom'), - ), - ), - }, - ) + wf = sl.Pipeline(providers(), params=default_parameters()) return wf diff --git a/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py b/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py index 7d4860bbc..93a832ec7 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py +++ b/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py @@ -452,10 +452,12 @@ def _mask_large_uncertainty_in_lut(table: Lut, error_threshold: float) -> Lut: relative_error = sc.stddevs(da.data) / sc.values(da.data) mask = relative_error > sc.scalar(error_threshold) return LookupTable[RunType]( - **{ - **asdict(table), - "array": sc.where(mask, sc.scalar(np.nan, unit=da.unit), da), - } + Lut( + **{ + **asdict(table), + "array": sc.where(mask, sc.scalar(np.nan, unit=da.unit), da), + } + ) ) diff --git a/packages/essreduce/src/ess/reduce/unwrap/workflow.py b/packages/essreduce/src/ess/reduce/unwrap/workflow.py index c43585cf1..9da6b710c 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/workflow.py +++ b/packages/essreduce/src/ess/reduce/unwrap/workflow.py @@ -89,5 +89,7 @@ def GenericUnwrapWorkflow( # Default parameters wf[PulseStrideOffset] = None + for key, value in lut.default_parameters().items(): + wf[key] = value return wf From fe966561d029283a42ad6f08e6736a499f16eea6 Mon Sep 17 00:00:00 2001 From: Neil Vaytet Date: Thu, 21 May 2026 16:50:39 +0200 Subject: [PATCH 3/6] update unwrap notebook to integrate onthefly lut computation --- .../user-guide/unwrap/analytical-unwrap.ipynb | 165 ++++++------------ .../essreduce/src/ess/reduce/unwrap/lut.py | 5 +- 2 files changed, 60 insertions(+), 110 deletions(-) diff --git a/packages/essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb b/packages/essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb index a6e3dfc19..68c872fa7 100644 --- a/packages/essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb +++ b/packages/essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb @@ -33,7 +33,7 @@ "import scipp as sc\n", "import scippnexus as snx\n", "from scippneutron.chopper import DiskChopper\n", - "from ess.reduce.nexus.types import AnyRun, RawDetector, SampleRun, NeXusDetectorName\n", + "from ess.reduce.nexus.types import AnyRun, RawDetector, SampleRun, NeXusDetectorName, Position\n", "from ess.reduce.unwrap import *" ] }, @@ -293,7 +293,7 @@ "source": [ "### Computing neutron wavelengths\n", "\n", - "Next, we use a workflow that provides an estimate of the neutron wavelength as a function of neutron time-of-arrival.\n", + "Next, we use a workflow that computes the neutron wavelength from the neutron time-of-arrival.\n", "\n", "#### Setting up the workflow" ] @@ -309,8 +309,11 @@ "\n", "wf[RawDetector[SampleRun]] = raw_data\n", "wf[DetectorLtotal[SampleRun]] = Ltotal\n", + "wf[Position[snx.NXsource, SampleRun]] = source_position\n", "wf[NeXusDetectorName] = 'dream_detector'\n", + "wf[DiskChoppers[SampleRun]] = disk_choppers\n", "wf[LookupTableRelativeErrorThreshold] = {'dream_detector': float(\"inf\")}\n", + "wf[LtotalRange] = (sc.scalar(5.0, unit='m'), sc.scalar(80.0, unit='m'))\n", "\n", "wf.visualize(WavelengthDetector[SampleRun])" ] @@ -319,77 +322,36 @@ "cell_type": "markdown", "id": "22", "metadata": {}, - "source": [ - "By default, the workflow tries to load a `LookupTable` from a file.\n", - "\n", - "In this notebook, instead of using such a pre-made file,\n", - "we will build our own lookup table from the chopper information and apply it to the workflow." - ] - }, - { - "cell_type": "markdown", - "id": "23", - "metadata": {}, - "source": [ - "#### Building the wavelength lookup table\n", - "\n", - "We use [`scippneutron.tof.chopper_cascade`](https://scipp.github.io/scippneutron/user-guide/chopper/chopper-cascade.html) module to propagate a pulse of neutrons through the chopper system to the detectors,\n", - "and predict the most likely neutron wavelength for a given time-of-arrival and distance from source.\n", - "\n", - "From this,\n", - "we build a lookup table on which bilinear interpolation is used to compute a wavelength for every neutron event." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "24", - "metadata": {}, - "outputs": [], - "source": [ - "lut_wf = LookupTableWorkflow(use_simulation=False)\n", - "lut_wf[DiskChoppers[AnyRun]] = disk_choppers\n", - "lut_wf[SourcePosition] = source_position\n", - "lut_wf[LtotalRange] = (\n", - " sc.scalar(25.0, unit=\"m\"),\n", - " sc.scalar(80.0, unit=\"m\"),\n", - ")\n", - "lut_wf.visualize(LookupTable)" - ] - }, - { - "cell_type": "markdown", - "id": "25", - "metadata": {}, "source": [ "#### Inspecting the lookup table\n", "\n", - "The workflow first runs a calculation propagating a pulse of neutrons (represented by a polygon in time and wavelength space),\n", + "The workflow first uses the [`scippneutron.tof.chopper_cascade`](https://scipp.github.io/scippneutron/user-guide/chopper/chopper-cascade.html)\n", + "module to propagate a pulse of neutrons (represented by a polygon in time and wavelength space)\n", "through a chopper cascade defined by the chopper parameters above.\n", "\n", "This can be used to create a figure displaying the neutron wavelengths,\n", "as a function of arrival time at the detector.\n", "\n", - "This is the basis for creating our lookup table." + "This is the basis for creating our lookup table, on which bilinear interpolation is used further down to compute a wavelength for every neutron event." ] }, { "cell_type": "code", "execution_count": null, - "id": "26", + "id": "23", "metadata": {}, "outputs": [], "source": [ "dist = sc.scalar(60.0, unit='m')\n", "\n", - "frames = lut_wf.compute(ChopperFrameSequence)\n", + "frames = wf.compute(ChopperFrameSequence[SampleRun])\n", "at_detector = frames.propagate_to(dist)\n", "fig, ax = at_detector.draw()" ] }, { "cell_type": "markdown", - "id": "27", + "id": "24", "metadata": {}, "source": [ "The source pulse is defined as spanning 0-5 ms in time, and 0-15 Å in wavelength,\n", @@ -412,14 +374,14 @@ { "cell_type": "code", "execution_count": null, - "id": "28", + "id": "25", "metadata": {}, "outputs": [], "source": [ - "table = lut_wf.compute(LookupTable)\n", + "table = wf.compute(LookupTable[SampleRun])\n", "\n", "# Overlay LUT prediction on the polygons figure\n", - "da = table.array[\"distance\", 352]\n", + "da = table.array[\"distance\", 552]\n", "ax.plot(\n", " da.coords['event_time_offset'].values / 1000,\n", " da.values,\n", @@ -433,7 +395,7 @@ { "cell_type": "code", "execution_count": null, - "id": "29", + "id": "26", "metadata": {}, "outputs": [], "source": [ @@ -443,7 +405,7 @@ }, { "cell_type": "markdown", - "id": "30", + "id": "27", "metadata": {}, "source": [ "The full table covers a range of distances, and looks like" @@ -452,7 +414,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31", + "id": "28", "metadata": {}, "outputs": [], "source": [ @@ -461,24 +423,21 @@ }, { "cell_type": "markdown", - "id": "32", + "id": "29", "metadata": {}, "source": [ "#### Computing a wavelength coordinate\n", "\n", - "We will now update our workflow, and use it to obtain our event data with a wavelength coordinate:" + "We will now use our workflow once again to obtain our event data with a wavelength coordinate (a step further down in the pipeline):" ] }, { "cell_type": "code", "execution_count": null, - "id": "33", + "id": "30", "metadata": {}, "outputs": [], "source": [ - "# Set the computed lookup table onto the original workflow\n", - "wf[LookupTable] = table\n", - "\n", "# Compute wavelength of neutron events\n", "wavs = wf.compute(WavelengthDetector[SampleRun])\n", "edges = sc.linspace(\"wavelength\", 0.8, 4.6, 201, unit=\"angstrom\")\n", @@ -489,7 +448,7 @@ }, { "cell_type": "markdown", - "id": "34", + "id": "31", "metadata": {}, "source": [ "#### Comparing to the ground truth\n", @@ -501,7 +460,7 @@ { "cell_type": "code", "execution_count": null, - "id": "35", + "id": "32", "metadata": {}, "outputs": [], "source": [ @@ -519,7 +478,7 @@ }, { "cell_type": "markdown", - "id": "36", + "id": "33", "metadata": {}, "source": [ "### Multiple detector pixels\n", @@ -535,7 +494,7 @@ { "cell_type": "code", "execution_count": null, - "id": "37", + "id": "34", "metadata": {}, "outputs": [], "source": [ @@ -553,7 +512,7 @@ }, { "cell_type": "markdown", - "id": "38", + "id": "35", "metadata": {}, "source": [ "Our raw data has now a `detector_number` dimension of length 2.\n", @@ -564,7 +523,7 @@ { "cell_type": "code", "execution_count": null, - "id": "39", + "id": "36", "metadata": {}, "outputs": [], "source": [ @@ -579,7 +538,7 @@ }, { "cell_type": "markdown", - "id": "40", + "id": "37", "metadata": {}, "source": [ "Computing wavelength is done in the same way as above.\n", @@ -589,7 +548,7 @@ { "cell_type": "code", "execution_count": null, - "id": "41", + "id": "38", "metadata": {}, "outputs": [], "source": [ @@ -623,7 +582,7 @@ }, { "cell_type": "markdown", - "id": "42", + "id": "39", "metadata": {}, "source": [ "### Handling time overlap between subframes\n", @@ -643,7 +602,7 @@ { "cell_type": "code", "execution_count": null, - "id": "43", + "id": "40", "metadata": {}, "outputs": [], "source": [ @@ -674,7 +633,7 @@ }, { "cell_type": "markdown", - "id": "44", + "id": "41", "metadata": {}, "source": [ "We can now see that there is no longer a gap between the two frames at the center of each pulse (green region).\n", @@ -686,14 +645,14 @@ { "cell_type": "code", "execution_count": null, - "id": "45", + "id": "42", "metadata": {}, "outputs": [], "source": [ "# Update workflow\n", - "lut_wf[DiskChoppers[AnyRun]] = disk_choppers\n", + "wf[DiskChoppers[AnyRun]] = disk_choppers\n", "\n", - "frames = lut_wf.compute(ChopperFrameSequence)\n", + "frames = wf.compute(ChopperFrameSequence[SampleRun])\n", "at_detector = frames.propagate_to(dist)\n", "fig, ax = at_detector.draw()\n", "ax.set(xlim=(36, 44), ylim=(2, 3))" @@ -701,7 +660,7 @@ }, { "cell_type": "markdown", - "id": "46", + "id": "43", "metadata": {}, "source": [ "The data in the lookup table contains both the mean wavelength for each distance and time-of-arrival bin,\n", @@ -719,11 +678,11 @@ { "cell_type": "code", "execution_count": null, - "id": "47", + "id": "44", "metadata": {}, "outputs": [], "source": [ - "table = lut_wf.compute(LookupTable)\n", + "table = wf.compute(LookupTable[SampleRun])\n", "table.plot(ymin=65) / (sc.stddevs(table.array) / sc.values(table.array)).plot(\n", " norm=\"linear\", ymin=55, vmax=0.05\n", ")" @@ -731,7 +690,7 @@ }, { "cell_type": "markdown", - "id": "48", + "id": "45", "metadata": {}, "source": [ "The workflow has a parameter which is used to mask out regions where the standard deviation is above a certain threshold.\n", @@ -745,21 +704,19 @@ { "cell_type": "code", "execution_count": null, - "id": "49", + "id": "46", "metadata": {}, "outputs": [], "source": [ - "wf[LookupTable] = table\n", - "\n", "wf[LookupTableRelativeErrorThreshold] = {'dream_detector': 0.02}\n", "\n", - "masked_table = wf.compute(ErrorLimitedLookupTable[snx.NXdetector])\n", + "masked_table = wf.compute(ErrorLimitedLookupTable[SampleRun, snx.NXdetector])\n", "masked_table.plot(ymin=65)" ] }, { "cell_type": "markdown", - "id": "50", + "id": "47", "metadata": {}, "source": [ "We can now see that the central region is masked out.\n", @@ -774,7 +731,7 @@ { "cell_type": "code", "execution_count": null, - "id": "51", + "id": "48", "metadata": {}, "outputs": [], "source": [ @@ -799,7 +756,7 @@ }, { "cell_type": "markdown", - "id": "52", + "id": "49", "metadata": {}, "source": [ "## The ODIN instrument\n", @@ -819,7 +776,7 @@ { "cell_type": "code", "execution_count": null, - "id": "53", + "id": "50", "metadata": {}, "outputs": [], "source": [ @@ -932,7 +889,7 @@ { "cell_type": "code", "execution_count": null, - "id": "54", + "id": "51", "metadata": {}, "outputs": [], "source": [ @@ -944,7 +901,7 @@ }, { "cell_type": "markdown", - "id": "55", + "id": "52", "metadata": {}, "source": [ "### Creating the lookup table for ODIN\n", @@ -957,27 +914,22 @@ { "cell_type": "code", "execution_count": null, - "id": "56", + "id": "53", "metadata": {}, "outputs": [], "source": [ - "lut_wf = LookupTableWorkflow(use_simulation=False)\n", - "lut_wf[DiskChoppers[AnyRun]] = odin_choppers\n", - "lut_wf[SourcePosition] = source_position\n", - "lut_wf[LtotalRange] = (\n", - " sc.scalar(25.0, unit=\"m\"),\n", - " sc.scalar(65.0, unit=\"m\"),\n", - ")\n", - "lut_wf[PulseStride] = 2\n", + "wf[DiskChoppers[SampleRun]] = odin_choppers\n", + "wf[Position[snx.NXsource, SampleRun]] = source_position\n", + "wf[PulseStride] = 2\n", "\n", - "frames = lut_wf.compute(ChopperFrameSequence)\n", + "frames = wf.compute(ChopperFrameSequence[SampleRun])\n", "at_detector = frames.propagate_to(Ltotal)\n", "fig, ax = at_detector.draw()\n", "\n", - "table = lut_wf.compute(LookupTable)\n", + "table = wf.compute(LookupTable[SampleRun])\n", "\n", "# Overlay LUT prediction on the polygons figure\n", - "da = table.array[\"distance\", 352]\n", + "da = table.array[\"distance\", 552]\n", "ax.plot(\n", " da.coords['event_time_offset'].values / 1000,\n", " da.values,\n", @@ -990,7 +942,7 @@ }, { "cell_type": "markdown", - "id": "57", + "id": "54", "metadata": {}, "source": [ "The final relation between time-of-arrival and wavelength at the detector is represented by the black lines that accurately trace the green polygons\n", @@ -1006,7 +958,7 @@ { "cell_type": "code", "execution_count": null, - "id": "58", + "id": "55", "metadata": {}, "outputs": [], "source": [ @@ -1015,7 +967,7 @@ }, { "cell_type": "markdown", - "id": "59", + "id": "56", "metadata": {}, "source": [ "### Computing wavelengths for ODIN\n", @@ -1026,19 +978,16 @@ { "cell_type": "code", "execution_count": null, - "id": "60", + "id": "57", "metadata": {}, "outputs": [], "source": [ - "wf = GenericUnwrapWorkflow(run_types=[SampleRun], monitor_types=[])\n", - "\n", "wf[RawDetector[SampleRun]] = raw_data\n", "wf[DetectorLtotal[SampleRun]] = Ltotal\n", "wf[NeXusDetectorName] = 'odin_detector'\n", "wf[LookupTableRelativeErrorThreshold] = {'odin_detector': float(\"inf\")}\n", "\n", "wf.visualize(WavelengthDetector[SampleRun])\n", - "wf[LookupTable] = table\n", "\n", "# Compute wavelength of neutron events\n", "wavs = wf.compute(WavelengthDetector[SampleRun])\n", diff --git a/packages/essreduce/src/ess/reduce/unwrap/lut.py b/packages/essreduce/src/ess/reduce/unwrap/lut.py index 155bace4d..89423796b 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/lut.py +++ b/packages/essreduce/src/ess/reduce/unwrap/lut.py @@ -771,7 +771,7 @@ def make_wavelength_lut_from_polygons( - table.coords["distance"][0], time_resolution=table.coords["event_time_offset"][1] - table.coords["event_time_offset"][0], - # TODO: Do we still want to store the chopper information in the lookup table? + # TODO: Do we still want to store the chopper info in the lookup table? ) ) @@ -791,7 +791,7 @@ def default_parameters() -> dict: PulsePeriod: 1.0 / sc.scalar(14.0, unit="Hz"), PulseStride: 1, DistanceResolution: sc.scalar(0.1, unit="m"), - TimeResolution: sc.scalar(250.0, unit='us'), + TimeResolution: sc.scalar(50.0, unit='us'), SourceBounds: SourceBounds( time=(sc.scalar(0.0, unit='ms'), sc.scalar(5.0, unit='ms')), wavelength=( @@ -799,6 +799,7 @@ def default_parameters() -> dict: sc.scalar(15.0, unit='angstrom'), ), ), + LtotalRange: (sc.scalar(5.0, unit='m'), sc.scalar(180.0, unit='m')), } From 933393014148a1e6be2b72067f57def9770ad917 Mon Sep 17 00:00:00 2001 From: Neil Vaytet Date: Fri, 22 May 2026 14:10:47 +0200 Subject: [PATCH 4/6] get ltotal range from component ltotal --- .../essreduce/src/ess/reduce/unwrap/lut.py | 145 ++++++++++++++---- .../src/ess/reduce/unwrap/to_wavelength.py | 4 +- .../essreduce/src/ess/reduce/unwrap/types.py | 4 +- 3 files changed, 117 insertions(+), 36 deletions(-) diff --git a/packages/essreduce/src/ess/reduce/unwrap/lut.py b/packages/essreduce/src/ess/reduce/unwrap/lut.py index 89423796b..07470c78b 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/lut.py +++ b/packages/essreduce/src/ess/reduce/unwrap/lut.py @@ -16,8 +16,8 @@ from scippneutron.chopper import DiskChopper from scippneutron.tof import chopper_cascade -from ..nexus.types import DiskChoppers, Position, RunType -from .types import LookupTable, Lut +from ..nexus.types import Component, DiskChoppers, MonitorType, Position, RunType +from .types import DetectorLtotal, LookupTable, Lut, MonitorLtotal @dataclass @@ -109,18 +109,26 @@ class SimulationResults( This is typically a large number, e.g., 1e6 or 1e7. """ -LtotalRange = NewType("LtotalRange", tuple[sc.Variable, sc.Variable]) -""" -Range (min, max) of the total length of the flight path from the source to the detector. -This is used to create the lookup table to compute the neutron time-of-flight. -Note that the resulting table will extend slightly beyond this range, as the supplied -range is not necessarily a multiple of the distance resolution. - -Note also that the range of total flight paths is supplied manually to the workflow -instead of being read from the input data, as it allows us to compute the expensive part -of the workflow in advance (the lookup table) and does not need to be repeated for each -run, or for new data coming in in the case of live data collection. -""" + +class LtotalRange( + sl.Scope[RunType, Component, tuple[sc.Variable, sc.Variable]], + tuple[sc.Variable, sc.Variable], +): + """ + Range (min, max) of the total length of the flight path from the source to the + detector. + This is used to create the lookup table to compute the neutron time-of-flight. + Note that the resulting table will extend slightly beyond this range, as the + supplied + range is not necessarily a multiple of the distance resolution. + + Note also that the range of total flight paths is supplied manually to the + workflow instead of being read from the input data, as it allows us to compute + the expensive part of the workflow in advance (the lookup table) and does not + need to be repeated for each run, or for new data coming in in the case of live + data collection. + """ + DistanceResolution = NewType("DistanceResolution", sc.Variable) """ @@ -252,14 +260,14 @@ def _compute_mean_wavelength( return mean_wavelength -def make_wavelength_lookup_table( +def _make_wavelength_lookup_table_from_simulation( simulation: SimulationResults[RunType], ltotal_range: LtotalRange, distance_resolution: DistanceResolution, time_resolution: TimeResolution, pulse_period: PulsePeriod, pulse_stride: PulseStride, -) -> LookupTable[RunType]: +) -> Lut: """ Compute a lookup table for wavelength as a function of distance and time-of-arrival. @@ -669,14 +677,14 @@ def compute_frame_sequence( return ChopperFrameSequence[RunType](frames) -def make_wavelength_lut_from_polygons( - ltotal_range: LtotalRange, - distance_resolution: DistanceResolution, - time_resolution: TimeResolution, - pulse_period: PulsePeriod, - pulse_stride: PulseStride, +def _make_wavelength_lut_from_polygons( + ltotal_range: tuple[sc.Variable, sc.Variable], + distance_resolution: sc.Variable, + time_resolution: sc.Variable, + pulse_period: sc.Variable, + pulse_stride: int, frames: ChopperFrameSequence, -) -> LookupTable[RunType]: +) -> Lut: """ Compute a lookup table for wavelength as a function of distance and time-of-arrival. @@ -762,16 +770,85 @@ def make_wavelength_lut_from_polygons( coords={"distance": distances, "event_time_offset": time_edges}, ) - return LookupTable[RunType]( - Lut( - array=table, + return Lut( + array=table, + pulse_period=pulse_period, + pulse_stride=pulse_stride, + distance_resolution=table.coords["distance"][1] - table.coords["distance"][0], + time_resolution=table.coords["event_time_offset"][1] + - table.coords["event_time_offset"][0], + # TODO: Do we still want to store the chopper info in the lookup table? + ) + + +def _ltotal_range_from_ltotal(ltotal: sc.Variable) -> tuple[sc.Variable, sc.Variable]: + return (ltotal.min(), ltotal.max()) + + +def ltotal_range_from_ltotal_detector( + ltotal: DetectorLtotal[RunType], +) -> LtotalRange[RunType, snx.NXdetector]: + """ + Compute the range of total flight path lengths from the source to the detector from + the ltotal variable in the input data for the detector workflow. + """ + return LtotalRange[RunType, snx.NXdetector](_ltotal_range_from_ltotal(ltotal)) + + +def ltotal_range_from_ltotal_monitor( + ltotal: MonitorLtotal[RunType], +) -> LtotalRange[RunType, MonitorType]: + """ + Compute the range of total flight path lengths from the source to the detector from + the ltotal variable in the input data for the monitor workflow. + """ + return LtotalRange[RunType, MonitorType](_ltotal_range_from_ltotal(ltotal)) + + +def make_wavelength_lut_from_polygons_detector( + ltotal_range: LtotalRange[RunType, snx.NXdetector], + distance_resolution: DistanceResolution, + time_resolution: TimeResolution, + pulse_period: PulsePeriod, + pulse_stride: PulseStride, + frames: ChopperFrameSequence, +) -> LookupTable[RunType, snx.NXdetector]: + """ + Wrapper around _make_wavelength_lut_from_polygons to specify the Component as + snx.NXdetector, for use in the detector workflow. + """ + return LookupTable[RunType, snx.NXdetector]( + _make_wavelength_lut_from_polygons( + ltotal_range=ltotal_range, + distance_resolution=distance_resolution, + time_resolution=time_resolution, + pulse_period=pulse_period, + pulse_stride=pulse_stride, + frames=frames, + ) + ) + + +def make_wavelength_lut_from_polygons_monitor( + ltotal_range: LtotalRange[RunType, MonitorType], + distance_resolution: DistanceResolution, + time_resolution: TimeResolution, + pulse_period: PulsePeriod, + pulse_stride: PulseStride, + frames: ChopperFrameSequence, +) -> LookupTable[RunType, MonitorType]: + """ + Wrapper around _make_wavelength_lut_from_polygons to specify the Component as + snx.NXmonitor, for use in the monitor workflow. + """ + return LookupTable[RunType, MonitorType]( + _make_wavelength_lut_from_polygons( + ltotal_range=ltotal_range, + distance_resolution=distance_resolution, + time_resolution=time_resolution, pulse_period=pulse_period, pulse_stride=pulse_stride, - distance_resolution=table.coords["distance"][1] - - table.coords["distance"][0], - time_resolution=table.coords["event_time_offset"][1] - - table.coords["event_time_offset"][0], - # TODO: Do we still want to store the chopper info in the lookup table? + frames=frames, ) ) @@ -783,7 +860,11 @@ def providers() -> tuple[Callable]: to compute the lookup table is expensive and not something we want to do by default. """ - return (make_wavelength_lut_from_polygons, compute_frame_sequence) + return ( + make_wavelength_lut_from_polygons_detector, + make_wavelength_lut_from_polygons_monitor, + compute_frame_sequence, + ) def default_parameters() -> dict: diff --git a/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py b/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py index 93a832ec7..5e379b485 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py +++ b/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py @@ -462,7 +462,7 @@ def _mask_large_uncertainty_in_lut(table: Lut, error_threshold: float) -> Lut: def mask_large_uncertainty_in_lut_detector( - table: LookupTable[RunType], + table: LookupTable[RunType, snx.NXdetector], error_threshold: LookupTableRelativeErrorThreshold, detector_name: NeXusDetectorName, ) -> ErrorLimitedLookupTable[RunType, snx.NXdetector]: @@ -488,7 +488,7 @@ def mask_large_uncertainty_in_lut_detector( def mask_large_uncertainty_in_lut_monitor( - table: LookupTable[RunType], + table: LookupTable[RunType, MonitorType], error_threshold: LookupTableRelativeErrorThreshold, monitor_name: NeXusName[MonitorType], ) -> ErrorLimitedLookupTable[RunType, MonitorType]: diff --git a/packages/essreduce/src/ess/reduce/unwrap/types.py b/packages/essreduce/src/ess/reduce/unwrap/types.py index babcccd5f..354f0c7eb 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/types.py +++ b/packages/essreduce/src/ess/reduce/unwrap/types.py @@ -45,9 +45,9 @@ def plot(self, *args, **kwargs) -> Any: return self.array.plot(*args, **kwargs) -class LookupTable(sl.Scope[RunType, Lut], Lut): +class LookupTable(sl.Scope[RunType, Component, Lut], Lut): """Lookup table giving wavelength as a function of distance and - ``event_time_offset``.""" + ``event_time_offset`` for each beamline component (detector, monitor).""" class ErrorLimitedLookupTable(sl.Scope[RunType, Component, Lut], Lut): From 97c0b9baa3191634300d137ea3018651cf8a6c30 Mon Sep 17 00:00:00 2001 From: Neil Vaytet Date: Fri, 22 May 2026 23:47:07 +0200 Subject: [PATCH 5/6] make workflow work with lut per component --- .../user-guide/unwrap/analytical-unwrap.ipynb | 32 ++++---- .../essreduce/src/ess/reduce/unwrap/lut.py | 76 ++++++++++++++++--- .../src/ess/reduce/unwrap/to_wavelength.py | 12 ++- 3 files changed, 86 insertions(+), 34 deletions(-) diff --git a/packages/essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb b/packages/essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb index 68c872fa7..617187431 100644 --- a/packages/essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb +++ b/packages/essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb @@ -33,7 +33,7 @@ "import scipp as sc\n", "import scippnexus as snx\n", "from scippneutron.chopper import DiskChopper\n", - "from ess.reduce.nexus.types import AnyRun, RawDetector, SampleRun, NeXusDetectorName, Position\n", + "from ess.reduce.nexus.types import AnyRun, RawDetector, SampleRun, NeXusDetectorName, Position, FrameMonitor0\n", "from ess.reduce.unwrap import *" ] }, @@ -187,7 +187,7 @@ "metadata": {}, "outputs": [], "source": [ - "Ltotal = sc.scalar(76.55 + 1.125, unit=\"m\")" + "Ltotal = sc.scalar(60.0, unit=\"m\")" ] }, { @@ -305,7 +305,7 @@ "metadata": {}, "outputs": [], "source": [ - "wf = GenericUnwrapWorkflow(run_types=[SampleRun], monitor_types=[])\n", + "wf = GenericUnwrapWorkflow(run_types=[SampleRun], monitor_types=[FrameMonitor0])\n", "\n", "wf[RawDetector[SampleRun]] = raw_data\n", "wf[DetectorLtotal[SampleRun]] = Ltotal\n", @@ -313,7 +313,6 @@ "wf[NeXusDetectorName] = 'dream_detector'\n", "wf[DiskChoppers[SampleRun]] = disk_choppers\n", "wf[LookupTableRelativeErrorThreshold] = {'dream_detector': float(\"inf\")}\n", - "wf[LtotalRange] = (sc.scalar(5.0, unit='m'), sc.scalar(80.0, unit='m'))\n", "\n", "wf.visualize(WavelengthDetector[SampleRun])" ] @@ -378,10 +377,10 @@ "metadata": {}, "outputs": [], "source": [ - "table = wf.compute(LookupTable[SampleRun])\n", + "table = wf.compute(LookupTable[SampleRun, snx.NXdetector])\n", "\n", "# Overlay LUT prediction on the polygons figure\n", - "da = table.array[\"distance\", 552]\n", + "da = table.array[\"distance\", 2]\n", "ax.plot(\n", " da.coords['event_time_offset'].values / 1000,\n", " da.values,\n", @@ -498,7 +497,7 @@ "metadata": {}, "outputs": [], "source": [ - "Ltotal = sc.array(dims=[\"detector_number\"], values=[77.675, 76.0], unit=\"m\")\n", + "Ltotal = sc.array(dims=[\"detector_number\"], values=[59.0, 60.0], unit=\"m\")\n", "monitors = {f\"detector{i}\": ltot for i, ltot in enumerate(Ltotal)}\n", "\n", "ess_beamline = FakeBeamline(\n", @@ -618,7 +617,7 @@ ")\n", "\n", "# Go back to a single detector pixel\n", - "Ltotal = sc.scalar(76.55 + 1.125, unit=\"m\")\n", + "Ltotal = sc.scalar(76.0, unit=\"m\")\n", "\n", "ess_beamline = FakeBeamline(\n", " choppers=disk_choppers,\n", @@ -650,7 +649,8 @@ "outputs": [], "source": [ "# Update workflow\n", - "wf[DiskChoppers[AnyRun]] = disk_choppers\n", + "wf[DiskChoppers[SampleRun]] = disk_choppers\n", + "wf[DetectorLtotal[SampleRun]] = Ltotal\n", "\n", "frames = wf.compute(ChopperFrameSequence[SampleRun])\n", "at_detector = frames.propagate_to(dist)\n", @@ -682,9 +682,8 @@ "metadata": {}, "outputs": [], "source": [ - "table = wf.compute(LookupTable[SampleRun])\n", - "table.plot(ymin=65) / (sc.stddevs(table.array) / sc.values(table.array)).plot(\n", - " norm=\"linear\", ymin=55, vmax=0.05\n", + "table = wf.compute(LookupTable[SampleRun, snx.NXdetector])\n", + "table.plot() / (sc.stddevs(table.array) / sc.values(table.array)).plot(\n", ")" ] }, @@ -711,7 +710,7 @@ "wf[LookupTableRelativeErrorThreshold] = {'dream_detector': 0.02}\n", "\n", "masked_table = wf.compute(ErrorLimitedLookupTable[SampleRun, snx.NXdetector])\n", - "masked_table.plot(ymin=65)" + "masked_table.plot()" ] }, { @@ -736,7 +735,6 @@ "outputs": [], "source": [ "wf[RawDetector[SampleRun]] = ess_beamline.get_monitor(\"detector\")[0]\n", - "wf[DetectorLtotal[SampleRun]] = Ltotal\n", "\n", "# Compute wavelength\n", "wav_wfm = wf.compute(WavelengthDetector[SampleRun])\n", @@ -920,16 +918,17 @@ "source": [ "wf[DiskChoppers[SampleRun]] = odin_choppers\n", "wf[Position[snx.NXsource, SampleRun]] = source_position\n", + "wf[DetectorLtotal[SampleRun]] = Ltotal\n", "wf[PulseStride] = 2\n", "\n", "frames = wf.compute(ChopperFrameSequence[SampleRun])\n", "at_detector = frames.propagate_to(Ltotal)\n", "fig, ax = at_detector.draw()\n", "\n", - "table = wf.compute(LookupTable[SampleRun])\n", + "table = wf.compute(LookupTable[SampleRun, snx.NXdetector])\n", "\n", "# Overlay LUT prediction on the polygons figure\n", - "da = table.array[\"distance\", 552]\n", + "da = table.array[\"distance\", 2]\n", "ax.plot(\n", " da.coords['event_time_offset'].values / 1000,\n", " da.values,\n", @@ -983,7 +982,6 @@ "outputs": [], "source": [ "wf[RawDetector[SampleRun]] = raw_data\n", - "wf[DetectorLtotal[SampleRun]] = Ltotal\n", "wf[NeXusDetectorName] = 'odin_detector'\n", "wf[LookupTableRelativeErrorThreshold] = {'odin_detector': float(\"inf\")}\n", "\n", diff --git a/packages/essreduce/src/ess/reduce/unwrap/lut.py b/packages/essreduce/src/ess/reduce/unwrap/lut.py index 07470c78b..5baa83bdd 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/lut.py +++ b/packages/essreduce/src/ess/reduce/unwrap/lut.py @@ -260,13 +260,13 @@ def _compute_mean_wavelength( return mean_wavelength -def _make_wavelength_lookup_table_from_simulation( - simulation: SimulationResults[RunType], - ltotal_range: LtotalRange, - distance_resolution: DistanceResolution, - time_resolution: TimeResolution, - pulse_period: PulsePeriod, - pulse_stride: PulseStride, +def _make_wavelength_lut_from_simulation( + simulation: SimulationResultsBaseClass, + ltotal_range: tuple[sc.Variable, sc.Variable], + distance_resolution: sc.Variable, + time_resolution: sc.Variable, + pulse_period: sc.Variable, + pulse_stride: int, ) -> Lut: """ Compute a lookup table for wavelength as a function of distance and @@ -430,6 +430,54 @@ def _make_wavelength_lookup_table_from_simulation( ) +def make_wavelength_lut_from_simulation_detector( + simulation: SimulationResults[RunType], + ltotal_range: LtotalRange[RunType, snx.NXdetector], + distance_resolution: DistanceResolution, + time_resolution: TimeResolution, + pulse_period: PulsePeriod, + pulse_stride: PulseStride, +) -> LookupTable[RunType, snx.NXdetector]: + """ + Wrapper around _make_wavelength_lut_from_simulation to specify the Component as + snx.NXdetector, for use in the detector workflow. + """ + return LookupTable[RunType, snx.NXdetector]( + _make_wavelength_lut_from_simulation( + simulation=simulation, + ltotal_range=ltotal_range, + distance_resolution=distance_resolution, + time_resolution=time_resolution, + pulse_period=pulse_period, + pulse_stride=pulse_stride, + ) + ) + + +def make_wavelength_lut_from_simulation_monitor( + simulation: SimulationResults[RunType], + ltotal_range: LtotalRange[RunType, MonitorType], + distance_resolution: DistanceResolution, + time_resolution: TimeResolution, + pulse_period: PulsePeriod, + pulse_stride: PulseStride, +) -> LookupTable[RunType, MonitorType]: + """ + Wrapper around _make_wavelength_lut_from_simulation to specify the Component as + snx.NXmonitor, for use in the monitor workflow. + """ + return LookupTable[RunType, MonitorType]( + _make_wavelength_lut_from_simulation( + simulation=simulation, + ltotal_range=ltotal_range, + distance_resolution=distance_resolution, + time_resolution=time_resolution, + pulse_period=pulse_period, + pulse_stride=pulse_stride, + ) + ) + + def _to_component_reading(component) -> BeamlineComponentReading: events = component.data.squeeze().flatten(to='event') sel = sc.full(value=True, sizes=events.sizes) @@ -796,7 +844,7 @@ def ltotal_range_from_ltotal_detector( def ltotal_range_from_ltotal_monitor( - ltotal: MonitorLtotal[RunType], + ltotal: MonitorLtotal[RunType, MonitorType], ) -> LtotalRange[RunType, MonitorType]: """ Compute the range of total flight path lengths from the source to the detector from @@ -861,6 +909,8 @@ def providers() -> tuple[Callable]: default. """ return ( + ltotal_range_from_ltotal_detector, + ltotal_range_from_ltotal_monitor, make_wavelength_lut_from_polygons_detector, make_wavelength_lut_from_polygons_monitor, compute_frame_sequence, @@ -880,7 +930,7 @@ def default_parameters() -> dict: sc.scalar(15.0, unit='angstrom'), ), ), - LtotalRange: (sc.scalar(5.0, unit='m'), sc.scalar(180.0, unit='m')), + # LtotalRange: (sc.scalar(5.0, unit='m'), sc.scalar(180.0, unit='m')), } @@ -902,7 +952,13 @@ def LookupTableWorkflow(use_simulation: bool = True): """ default_params = default_parameters() if use_simulation: - provs = (make_wavelength_lookup_table, simulate_chopper_cascade_using_tof) + provs = ( + ltotal_range_from_ltotal_detector, + ltotal_range_from_ltotal_monitor, + make_wavelength_lut_from_simulation_detector, + make_wavelength_lut_from_simulation_monitor, + simulate_chopper_cascade_using_tof, + ) default_params.update( { NumberOfSimulatedNeutrons: 1_000_000, diff --git a/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py b/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py index 5e379b485..bf56f7307 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py +++ b/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py @@ -451,13 +451,11 @@ def _mask_large_uncertainty_in_lut(table: Lut, error_threshold: float) -> Lut: da = table.array relative_error = sc.stddevs(da.data) / sc.values(da.data) mask = relative_error > sc.scalar(error_threshold) - return LookupTable[RunType]( - Lut( - **{ - **asdict(table), - "array": sc.where(mask, sc.scalar(np.nan, unit=da.unit), da), - } - ) + return Lut( + **{ + **asdict(table), + "array": sc.where(mask, sc.scalar(np.nan, unit=da.unit), da), + } ) From e2ebb99c14bee6ae890dfb25fa830067c1133761 Mon Sep 17 00:00:00 2001 From: Neil Vaytet Date: Fri, 22 May 2026 23:50:15 +0200 Subject: [PATCH 6/6] lint --- .../essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb | 2 +- packages/essreduce/src/ess/reduce/unwrap/lut.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb b/packages/essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb index 617187431..2b49ac125 100644 --- a/packages/essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb +++ b/packages/essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb @@ -33,7 +33,7 @@ "import scipp as sc\n", "import scippnexus as snx\n", "from scippneutron.chopper import DiskChopper\n", - "from ess.reduce.nexus.types import AnyRun, RawDetector, SampleRun, NeXusDetectorName, Position, FrameMonitor0\n", + "from ess.reduce.nexus.types import RawDetector, SampleRun, NeXusDetectorName, Position, FrameMonitor0\n", "from ess.reduce.unwrap import *" ] }, diff --git a/packages/essreduce/src/ess/reduce/unwrap/lut.py b/packages/essreduce/src/ess/reduce/unwrap/lut.py index 5baa83bdd..9d0afd52b 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/lut.py +++ b/packages/essreduce/src/ess/reduce/unwrap/lut.py @@ -930,7 +930,6 @@ def default_parameters() -> dict: sc.scalar(15.0, unit='angstrom'), ), ), - # LtotalRange: (sc.scalar(5.0, unit='m'), sc.scalar(180.0, unit='m')), }