diff --git a/packages/essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb b/packages/essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb index a6e3dfc19..2b49ac125 100644 --- a/packages/essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb +++ b/packages/essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb @@ -33,7 +33,7 @@ "import scipp as sc\n", "import scippnexus as snx\n", "from scippneutron.chopper import DiskChopper\n", - "from ess.reduce.nexus.types import AnyRun, RawDetector, SampleRun, NeXusDetectorName\n", + "from ess.reduce.nexus.types import RawDetector, SampleRun, NeXusDetectorName, Position, FrameMonitor0\n", "from ess.reduce.unwrap import *" ] }, @@ -187,7 +187,7 @@ "metadata": {}, "outputs": [], "source": [ - "Ltotal = sc.scalar(76.55 + 1.125, unit=\"m\")" + "Ltotal = sc.scalar(60.0, unit=\"m\")" ] }, { @@ -293,7 +293,7 @@ "source": [ "### Computing neutron wavelengths\n", "\n", - "Next, we use a workflow that provides an estimate of the neutron wavelength as a function of neutron time-of-arrival.\n", + "Next, we use a workflow that computes the neutron wavelength from the neutron time-of-arrival.\n", "\n", "#### Setting up the workflow" ] @@ -305,11 +305,13 @@ "metadata": {}, "outputs": [], "source": [ - "wf = GenericUnwrapWorkflow(run_types=[SampleRun], monitor_types=[])\n", + "wf = GenericUnwrapWorkflow(run_types=[SampleRun], monitor_types=[FrameMonitor0])\n", "\n", "wf[RawDetector[SampleRun]] = raw_data\n", "wf[DetectorLtotal[SampleRun]] = Ltotal\n", + "wf[Position[snx.NXsource, SampleRun]] = source_position\n", "wf[NeXusDetectorName] = 'dream_detector'\n", + "wf[DiskChoppers[SampleRun]] = disk_choppers\n", "wf[LookupTableRelativeErrorThreshold] = {'dream_detector': float(\"inf\")}\n", "\n", "wf.visualize(WavelengthDetector[SampleRun])" @@ -319,77 +321,36 @@ "cell_type": "markdown", "id": "22", "metadata": {}, - "source": [ - "By default, the workflow tries to load a `LookupTable` from a file.\n", - "\n", - "In this notebook, instead of using such a pre-made file,\n", - "we will build our own lookup table from the chopper information and apply it to the workflow." - ] - }, - { - "cell_type": "markdown", - "id": "23", - "metadata": {}, - "source": [ - "#### Building the wavelength lookup table\n", - "\n", - "We use [`scippneutron.tof.chopper_cascade`](https://scipp.github.io/scippneutron/user-guide/chopper/chopper-cascade.html) module to propagate a pulse of neutrons through the chopper system to the detectors,\n", - "and predict the most likely neutron wavelength for a given time-of-arrival and distance from source.\n", - "\n", - "From this,\n", - "we build a lookup table on which bilinear interpolation is used to compute a wavelength for every neutron event." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "24", - "metadata": {}, - "outputs": [], - "source": [ - "lut_wf = LookupTableWorkflow(use_simulation=False)\n", - "lut_wf[DiskChoppers[AnyRun]] = disk_choppers\n", - "lut_wf[SourcePosition] = source_position\n", - "lut_wf[LtotalRange] = (\n", - " sc.scalar(25.0, unit=\"m\"),\n", - " sc.scalar(80.0, unit=\"m\"),\n", - ")\n", - "lut_wf.visualize(LookupTable)" - ] - }, - { - "cell_type": "markdown", - "id": "25", - "metadata": {}, "source": [ "#### Inspecting the lookup table\n", "\n", - "The workflow first runs a calculation propagating a pulse of neutrons (represented by a polygon in time and wavelength space),\n", + "The workflow first uses the [`scippneutron.tof.chopper_cascade`](https://scipp.github.io/scippneutron/user-guide/chopper/chopper-cascade.html)\n", + "module to propagate a pulse of neutrons (represented by a polygon in time and wavelength space)\n", "through a chopper cascade defined by the chopper parameters above.\n", "\n", "This can be used to create a figure displaying the neutron wavelengths,\n", "as a function of arrival time at the detector.\n", "\n", - "This is the basis for creating our lookup table." + "This is the basis for creating our lookup table, on which bilinear interpolation is used further down to compute a wavelength for every neutron event." ] }, { "cell_type": "code", "execution_count": null, - "id": "26", + "id": "23", "metadata": {}, "outputs": [], "source": [ "dist = sc.scalar(60.0, unit='m')\n", "\n", - "frames = lut_wf.compute(ChopperFrameSequence)\n", + "frames = wf.compute(ChopperFrameSequence[SampleRun])\n", "at_detector = frames.propagate_to(dist)\n", "fig, ax = at_detector.draw()" ] }, { "cell_type": "markdown", - "id": "27", + "id": "24", "metadata": {}, "source": [ "The source pulse is defined as spanning 0-5 ms in time, and 0-15 Å in wavelength,\n", @@ -412,14 +373,14 @@ { "cell_type": "code", "execution_count": null, - "id": "28", + "id": "25", "metadata": {}, "outputs": [], "source": [ - "table = lut_wf.compute(LookupTable)\n", + "table = wf.compute(LookupTable[SampleRun, snx.NXdetector])\n", "\n", "# Overlay LUT prediction on the polygons figure\n", - "da = table.array[\"distance\", 352]\n", + "da = table.array[\"distance\", 2]\n", "ax.plot(\n", " da.coords['event_time_offset'].values / 1000,\n", " da.values,\n", @@ -433,7 +394,7 @@ { "cell_type": "code", "execution_count": null, - "id": "29", + "id": "26", "metadata": {}, "outputs": [], "source": [ @@ -443,7 +404,7 @@ }, { "cell_type": "markdown", - "id": "30", + "id": "27", "metadata": {}, "source": [ "The full table covers a range of distances, and looks like" @@ -452,7 +413,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31", + "id": "28", "metadata": {}, "outputs": [], "source": [ @@ -461,24 +422,21 @@ }, { "cell_type": "markdown", - "id": "32", + "id": "29", "metadata": {}, "source": [ "#### Computing a wavelength coordinate\n", "\n", - "We will now update our workflow, and use it to obtain our event data with a wavelength coordinate:" + "We will now use our workflow once again to obtain our event data with a wavelength coordinate (a step further down in the pipeline):" ] }, { "cell_type": "code", "execution_count": null, - "id": "33", + "id": "30", "metadata": {}, "outputs": [], "source": [ - "# Set the computed lookup table onto the original workflow\n", - "wf[LookupTable] = table\n", - "\n", "# Compute wavelength of neutron events\n", "wavs = wf.compute(WavelengthDetector[SampleRun])\n", "edges = sc.linspace(\"wavelength\", 0.8, 4.6, 201, unit=\"angstrom\")\n", @@ -489,7 +447,7 @@ }, { "cell_type": "markdown", - "id": "34", + "id": "31", "metadata": {}, "source": [ "#### Comparing to the ground truth\n", @@ -501,7 +459,7 @@ { "cell_type": "code", "execution_count": null, - "id": "35", + "id": "32", "metadata": {}, "outputs": [], "source": [ @@ -519,7 +477,7 @@ }, { "cell_type": "markdown", - "id": "36", + "id": "33", "metadata": {}, "source": [ "### Multiple detector pixels\n", @@ -535,11 +493,11 @@ { "cell_type": "code", "execution_count": null, - "id": "37", + "id": "34", "metadata": {}, "outputs": [], "source": [ - "Ltotal = sc.array(dims=[\"detector_number\"], values=[77.675, 76.0], unit=\"m\")\n", + "Ltotal = sc.array(dims=[\"detector_number\"], values=[59.0, 60.0], unit=\"m\")\n", "monitors = {f\"detector{i}\": ltot for i, ltot in enumerate(Ltotal)}\n", "\n", "ess_beamline = FakeBeamline(\n", @@ -553,7 +511,7 @@ }, { "cell_type": "markdown", - "id": "38", + "id": "35", "metadata": {}, "source": [ "Our raw data has now a `detector_number` dimension of length 2.\n", @@ -564,7 +522,7 @@ { "cell_type": "code", "execution_count": null, - "id": "39", + "id": "36", "metadata": {}, "outputs": [], "source": [ @@ -579,7 +537,7 @@ }, { "cell_type": "markdown", - "id": "40", + "id": "37", "metadata": {}, "source": [ "Computing wavelength is done in the same way as above.\n", @@ -589,7 +547,7 @@ { "cell_type": "code", "execution_count": null, - "id": "41", + "id": "38", "metadata": {}, "outputs": [], "source": [ @@ -623,7 +581,7 @@ }, { "cell_type": "markdown", - "id": "42", + "id": "39", "metadata": {}, "source": [ "### Handling time overlap between subframes\n", @@ -643,7 +601,7 @@ { "cell_type": "code", "execution_count": null, - "id": "43", + "id": "40", "metadata": {}, "outputs": [], "source": [ @@ -659,7 +617,7 @@ ")\n", "\n", "# Go back to a single detector pixel\n", - "Ltotal = sc.scalar(76.55 + 1.125, unit=\"m\")\n", + "Ltotal = sc.scalar(76.0, unit=\"m\")\n", "\n", "ess_beamline = FakeBeamline(\n", " choppers=disk_choppers,\n", @@ -674,7 +632,7 @@ }, { "cell_type": "markdown", - "id": "44", + "id": "41", "metadata": {}, "source": [ "We can now see that there is no longer a gap between the two frames at the center of each pulse (green region).\n", @@ -686,14 +644,15 @@ { "cell_type": "code", "execution_count": null, - "id": "45", + "id": "42", "metadata": {}, "outputs": [], "source": [ "# Update workflow\n", - "lut_wf[DiskChoppers[AnyRun]] = disk_choppers\n", + "wf[DiskChoppers[SampleRun]] = disk_choppers\n", + "wf[DetectorLtotal[SampleRun]] = Ltotal\n", "\n", - "frames = lut_wf.compute(ChopperFrameSequence)\n", + "frames = wf.compute(ChopperFrameSequence[SampleRun])\n", "at_detector = frames.propagate_to(dist)\n", "fig, ax = at_detector.draw()\n", "ax.set(xlim=(36, 44), ylim=(2, 3))" @@ -701,7 +660,7 @@ }, { "cell_type": "markdown", - "id": "46", + "id": "43", "metadata": {}, "source": [ "The data in the lookup table contains both the mean wavelength for each distance and time-of-arrival bin,\n", @@ -719,19 +678,18 @@ { "cell_type": "code", "execution_count": null, - "id": "47", + "id": "44", "metadata": {}, "outputs": [], "source": [ - "table = lut_wf.compute(LookupTable)\n", - "table.plot(ymin=65) / (sc.stddevs(table.array) / sc.values(table.array)).plot(\n", - " norm=\"linear\", ymin=55, vmax=0.05\n", + "table = wf.compute(LookupTable[SampleRun, snx.NXdetector])\n", + "table.plot() / (sc.stddevs(table.array) / sc.values(table.array)).plot(\n", ")" ] }, { "cell_type": "markdown", - "id": "48", + "id": "45", "metadata": {}, "source": [ "The workflow has a parameter which is used to mask out regions where the standard deviation is above a certain threshold.\n", @@ -745,21 +703,19 @@ { "cell_type": "code", "execution_count": null, - "id": "49", + "id": "46", "metadata": {}, "outputs": [], "source": [ - "wf[LookupTable] = table\n", - "\n", "wf[LookupTableRelativeErrorThreshold] = {'dream_detector': 0.02}\n", "\n", - "masked_table = wf.compute(ErrorLimitedLookupTable[snx.NXdetector])\n", - "masked_table.plot(ymin=65)" + "masked_table = wf.compute(ErrorLimitedLookupTable[SampleRun, snx.NXdetector])\n", + "masked_table.plot()" ] }, { "cell_type": "markdown", - "id": "50", + "id": "47", "metadata": {}, "source": [ "We can now see that the central region is masked out.\n", @@ -774,12 +730,11 @@ { "cell_type": "code", "execution_count": null, - "id": "51", + "id": "48", "metadata": {}, "outputs": [], "source": [ "wf[RawDetector[SampleRun]] = ess_beamline.get_monitor(\"detector\")[0]\n", - "wf[DetectorLtotal[SampleRun]] = Ltotal\n", "\n", "# Compute wavelength\n", "wav_wfm = wf.compute(WavelengthDetector[SampleRun])\n", @@ -799,7 +754,7 @@ }, { "cell_type": "markdown", - "id": "52", + "id": "49", "metadata": {}, "source": [ "## The ODIN instrument\n", @@ -819,7 +774,7 @@ { "cell_type": "code", "execution_count": null, - "id": "53", + "id": "50", "metadata": {}, "outputs": [], "source": [ @@ -932,7 +887,7 @@ { "cell_type": "code", "execution_count": null, - "id": "54", + "id": "51", "metadata": {}, "outputs": [], "source": [ @@ -944,7 +899,7 @@ }, { "cell_type": "markdown", - "id": "55", + "id": "52", "metadata": {}, "source": [ "### Creating the lookup table for ODIN\n", @@ -957,27 +912,23 @@ { "cell_type": "code", "execution_count": null, - "id": "56", + "id": "53", "metadata": {}, "outputs": [], "source": [ - "lut_wf = LookupTableWorkflow(use_simulation=False)\n", - "lut_wf[DiskChoppers[AnyRun]] = odin_choppers\n", - "lut_wf[SourcePosition] = source_position\n", - "lut_wf[LtotalRange] = (\n", - " sc.scalar(25.0, unit=\"m\"),\n", - " sc.scalar(65.0, unit=\"m\"),\n", - ")\n", - "lut_wf[PulseStride] = 2\n", + "wf[DiskChoppers[SampleRun]] = odin_choppers\n", + "wf[Position[snx.NXsource, SampleRun]] = source_position\n", + "wf[DetectorLtotal[SampleRun]] = Ltotal\n", + "wf[PulseStride] = 2\n", "\n", - "frames = lut_wf.compute(ChopperFrameSequence)\n", + "frames = wf.compute(ChopperFrameSequence[SampleRun])\n", "at_detector = frames.propagate_to(Ltotal)\n", "fig, ax = at_detector.draw()\n", "\n", - "table = lut_wf.compute(LookupTable)\n", + "table = wf.compute(LookupTable[SampleRun, snx.NXdetector])\n", "\n", "# Overlay LUT prediction on the polygons figure\n", - "da = table.array[\"distance\", 352]\n", + "da = table.array[\"distance\", 2]\n", "ax.plot(\n", " da.coords['event_time_offset'].values / 1000,\n", " da.values,\n", @@ -990,7 +941,7 @@ }, { "cell_type": "markdown", - "id": "57", + "id": "54", "metadata": {}, "source": [ "The final relation between time-of-arrival and wavelength at the detector is represented by the black lines that accurately trace the green polygons\n", @@ -1006,7 +957,7 @@ { "cell_type": "code", "execution_count": null, - "id": "58", + "id": "55", "metadata": {}, "outputs": [], "source": [ @@ -1015,7 +966,7 @@ }, { "cell_type": "markdown", - "id": "59", + "id": "56", "metadata": {}, "source": [ "### Computing wavelengths for ODIN\n", @@ -1026,19 +977,15 @@ { "cell_type": "code", "execution_count": null, - "id": "60", + "id": "57", "metadata": {}, "outputs": [], "source": [ - "wf = GenericUnwrapWorkflow(run_types=[SampleRun], monitor_types=[])\n", - "\n", "wf[RawDetector[SampleRun]] = raw_data\n", - "wf[DetectorLtotal[SampleRun]] = Ltotal\n", "wf[NeXusDetectorName] = 'odin_detector'\n", "wf[LookupTableRelativeErrorThreshold] = {'odin_detector': float(\"inf\")}\n", "\n", "wf.visualize(WavelengthDetector[SampleRun])\n", - "wf[LookupTable] = table\n", "\n", "# Compute wavelength of neutron events\n", "wavs = wf.compute(WavelengthDetector[SampleRun])\n", diff --git a/packages/essreduce/src/ess/reduce/nexus/workflow.py b/packages/essreduce/src/ess/reduce/nexus/workflow.py index dcd0a7988..6e320d53f 100644 --- a/packages/essreduce/src/ess/reduce/nexus/workflow.py +++ b/packages/essreduce/src/ess/reduce/nexus/workflow.py @@ -15,7 +15,7 @@ import scippnexus as snx from scipp.constants import g from scipp.core import label_based_index_to_positional_index -from scippneutron.chopper import extract_chopper_from_nexus +from scippneutron.chopper import DiskChopper, extract_chopper_from_nexus from scippneutron.metadata import RadiationProbe, SourceType from . import _nexus_loader as nexus @@ -26,6 +26,7 @@ Component, DetectorBankSizes, DetectorPositionOffset, + DiskChoppers, DynamicPosition, EmptyDetector, EmptyMonitor, @@ -537,7 +538,7 @@ def assemble_monitor_data( return RawMonitor[RunType, MonitorType](_add_variances(da)) -def parse_disk_choppers( +def parse_raw_choppers( choppers: AllNeXusComponents[snx.NXdisk_chopper, RunType], ) -> RawChoppers[RunType]: """Convert the NeXus representation of a chopper to ours. @@ -559,6 +560,17 @@ def parse_disk_choppers( ) +def to_disk_choppers(choppers: RawChoppers[RunType]) -> DiskChoppers[RunType]: + disk_choppers = { + # If there is no beam_position, we set it to 0 by default. + key: DiskChopper.from_nexus( + {**{'beam_position': sc.scalar(0.0, unit='deg')}, **ch} + ) + for key, ch in choppers.items() + } + return DiskChoppers[RunType](disk_choppers) + + def load_proton_charge( parent_location: NeXusComponentLocationSpec[ProductionInfo, RunType], interval: TimeInterval[RunType], @@ -789,7 +801,7 @@ def load_source_metadata_from_nexus( assemble_detector_data, ) -_chopper_providers = (parse_disk_choppers,) +_chopper_providers = (parse_raw_choppers, to_disk_choppers) _metadata_providers = ( load_beamline_metadata_from_nexus, diff --git a/packages/essreduce/src/ess/reduce/unwrap/lut.py b/packages/essreduce/src/ess/reduce/unwrap/lut.py index 93bdbebde..9d0afd52b 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/lut.py +++ b/packages/essreduce/src/ess/reduce/unwrap/lut.py @@ -5,16 +5,19 @@ """ import warnings +from collections.abc import Callable from dataclasses import dataclass from typing import NewType import numpy as np import sciline as sl import scipp as sc +import scippnexus as snx +from scippneutron.chopper import DiskChopper from scippneutron.tof import chopper_cascade -from ..nexus.types import AnyRun, DiskChoppers -from .types import LookupTable +from ..nexus.types import Component, DiskChoppers, MonitorType, Position, RunType +from .types import DetectorLtotal, LookupTable, Lut, MonitorLtotal @dataclass @@ -52,7 +55,7 @@ def __post_init__(self): @dataclass -class SimulationResults: +class SimulationResultsBaseClass: """ Results of a time-of-flight simulation used to create a lookup table. It should contain readings at various positions along the beamline, e.g., at @@ -73,7 +76,31 @@ class SimulationResults: """ readings: dict[str, BeamlineComponentReading] - choppers: DiskChoppers[AnyRun] | None = None + choppers: dict[str, DiskChopper] | None = None + + +class SimulationResults( + sl.Scope[RunType, SimulationResultsBaseClass], + SimulationResultsBaseClass, +): + """ + Results of a time-of-flight simulation used to create a lookup table. + It should contain readings at various positions along the beamline, e.g., at + the source and after each chopper. + It also contains the chopper parameters used in the simulation, so it can be + determined if this simulation is compatible with a given experiment. + + Parameters + ---------- + readings: + A dict of :class:`BeamlineComponentReading` objects representing the readings at + various positions along the beamline. The keys in the dict should correspond to + the names of the components (e.g., 'source', 'chopper1', etc.). + choppers: + The chopper parameters used in the simulation (if any). These are used to verify + that the simulation is compatible with a given experiment (comparing chopper + openings, frequencies, phases, etc.). + """ NumberOfSimulatedNeutrons = NewType("NumberOfSimulatedNeutrons", int) @@ -82,18 +109,26 @@ class SimulationResults: This is typically a large number, e.g., 1e6 or 1e7. """ -LtotalRange = NewType("LtotalRange", tuple[sc.Variable, sc.Variable]) -""" -Range (min, max) of the total length of the flight path from the source to the detector. -This is used to create the lookup table to compute the neutron time-of-flight. -Note that the resulting table will extend slightly beyond this range, as the supplied -range is not necessarily a multiple of the distance resolution. - -Note also that the range of total flight paths is supplied manually to the workflow -instead of being read from the input data, as it allows us to compute the expensive part -of the workflow in advance (the lookup table) and does not need to be repeated for each -run, or for new data coming in in the case of live data collection. -""" + +class LtotalRange( + sl.Scope[RunType, Component, tuple[sc.Variable, sc.Variable]], + tuple[sc.Variable, sc.Variable], +): + """ + Range (min, max) of the total length of the flight path from the source to the + detector. + This is used to create the lookup table to compute the neutron time-of-flight. + Note that the resulting table will extend slightly beyond this range, as the + supplied + range is not necessarily a multiple of the distance resolution. + + Note also that the range of total flight paths is supplied manually to the + workflow instead of being read from the input data, as it allows us to compute + the expensive part of the workflow in advance (the lookup table) and does not + need to be repeated for each run, or for new data coming in in the case of live + data collection. + """ + DistanceResolution = NewType("DistanceResolution", sc.Variable) """ @@ -151,11 +186,13 @@ class SourceBounds: """Wavelength range (min, max) of the neutrons in the source pulse.""" -ChopperFrameSequence = NewType("ChopperFrameSequence", chopper_cascade.FrameSequence) -""" -Sequence of chopper frames used to compute the wavelength as a function of distance and -event_time_offset in the lookup table. -""" +class ChopperFrameSequence( + sl.Scope[RunType, chopper_cascade.FrameSequence], chopper_cascade.FrameSequence +): + """ + Sequence of chopper frames used to compute the wavelength as a function of distance + and event_time_offset in the lookup table. + """ def _compute_mean_wavelength( @@ -223,14 +260,14 @@ def _compute_mean_wavelength( return mean_wavelength -def make_wavelength_lookup_table( - simulation: SimulationResults, - ltotal_range: LtotalRange, - distance_resolution: DistanceResolution, - time_resolution: TimeResolution, - pulse_period: PulsePeriod, - pulse_stride: PulseStride, -) -> LookupTable: +def _make_wavelength_lut_from_simulation( + simulation: SimulationResultsBaseClass, + ltotal_range: tuple[sc.Variable, sc.Variable], + distance_resolution: sc.Variable, + time_resolution: sc.Variable, + pulse_period: sc.Variable, + pulse_stride: int, +) -> Lut: """ Compute a lookup table for wavelength as a function of distance and time-of-arrival. @@ -378,7 +415,7 @@ def make_wavelength_lookup_table( }, ) - return LookupTable( + return LookupTable[RunType]( array=table, pulse_period=pulse_period, pulse_stride=pulse_stride, @@ -393,7 +430,55 @@ def make_wavelength_lookup_table( ) -def _to_component_reading(component): +def make_wavelength_lut_from_simulation_detector( + simulation: SimulationResults[RunType], + ltotal_range: LtotalRange[RunType, snx.NXdetector], + distance_resolution: DistanceResolution, + time_resolution: TimeResolution, + pulse_period: PulsePeriod, + pulse_stride: PulseStride, +) -> LookupTable[RunType, snx.NXdetector]: + """ + Wrapper around _make_wavelength_lut_from_simulation to specify the Component as + snx.NXdetector, for use in the detector workflow. + """ + return LookupTable[RunType, snx.NXdetector]( + _make_wavelength_lut_from_simulation( + simulation=simulation, + ltotal_range=ltotal_range, + distance_resolution=distance_resolution, + time_resolution=time_resolution, + pulse_period=pulse_period, + pulse_stride=pulse_stride, + ) + ) + + +def make_wavelength_lut_from_simulation_monitor( + simulation: SimulationResults[RunType], + ltotal_range: LtotalRange[RunType, MonitorType], + distance_resolution: DistanceResolution, + time_resolution: TimeResolution, + pulse_period: PulsePeriod, + pulse_stride: PulseStride, +) -> LookupTable[RunType, MonitorType]: + """ + Wrapper around _make_wavelength_lut_from_simulation to specify the Component as + snx.NXmonitor, for use in the monitor workflow. + """ + return LookupTable[RunType, MonitorType]( + _make_wavelength_lut_from_simulation( + simulation=simulation, + ltotal_range=ltotal_range, + distance_resolution=distance_resolution, + time_resolution=time_resolution, + pulse_period=pulse_period, + pulse_stride=pulse_stride, + ) + ) + + +def _to_component_reading(component) -> BeamlineComponentReading: events = component.data.squeeze().flatten(to='event') sel = sc.full(value=True, sizes=events.sizes) for key in {'blocked_by_others', 'blocked_by_me'} & set(events.masks.keys()): @@ -412,13 +497,13 @@ def _to_component_reading(component): def simulate_chopper_cascade_using_tof( - choppers: DiskChoppers[AnyRun], + choppers: DiskChoppers[RunType], source_position: SourcePosition, neutrons: NumberOfSimulatedNeutrons, pulse_stride: PulseStride, seed: SimulationSeed, facility: SimulationFacility, -) -> SimulationResults: +) -> SimulationResults[RunType]: """ Simulate a pulse of neutrons propagating through a chopper cascade using the ``tof`` package (https://scipp.github.io/tof). @@ -457,12 +542,12 @@ def simulate_chopper_cascade_using_tof( ) sim_readings = {"source": _to_component_reading(source)} if not tof_choppers: - return SimulationResults(readings=sim_readings, choppers=None) + return SimulationResults[RunType](readings=sim_readings, choppers=None) model = tof.Model(source=source, choppers=tof_choppers) results = model.run() for name, ch in results.choppers.items(): sim_readings[name] = _to_component_reading(ch) - return SimulationResults(readings=sim_readings, choppers=choppers) + return SimulationResults[RunType](readings=sim_readings, choppers=choppers) def _polygon_intersections(polygons: list[np.ndarray], x: np.ndarray) -> np.ndarray: @@ -574,11 +659,11 @@ def _estimate_wavelength_by_polygon_centers( def compute_frame_sequence( pulse_period: PulsePeriod, - disk_choppers: DiskChoppers[AnyRun], - source_position: SourcePosition, + disk_choppers: DiskChoppers[RunType], + source_position: Position[snx.NXsource, RunType], source_bounds: SourceBounds, pulse_stride: PulseStride, -) -> ChopperFrameSequence: +) -> ChopperFrameSequence[RunType]: """ Compute the chopper frame sequence for a given set of disk choppers and source pulse parameters. @@ -637,17 +722,17 @@ def compute_frame_sequence( npulses=pulse_stride, ) frames = frames.chop(chops.values()) - return ChopperFrameSequence(frames) + return ChopperFrameSequence[RunType](frames) -def make_wavelength_lut_from_polygons( - ltotal_range: LtotalRange, - distance_resolution: DistanceResolution, - time_resolution: TimeResolution, - pulse_period: PulsePeriod, - pulse_stride: PulseStride, +def _make_wavelength_lut_from_polygons( + ltotal_range: tuple[sc.Variable, sc.Variable], + distance_resolution: sc.Variable, + time_resolution: sc.Variable, + pulse_period: sc.Variable, + pulse_stride: int, frames: ChopperFrameSequence, -) -> LookupTable: +) -> Lut: """ Compute a lookup table for wavelength as a function of distance and time-of-arrival. @@ -733,17 +818,121 @@ def make_wavelength_lut_from_polygons( coords={"distance": distances, "event_time_offset": time_edges}, ) - return LookupTable( + return Lut( array=table, pulse_period=pulse_period, pulse_stride=pulse_stride, distance_resolution=table.coords["distance"][1] - table.coords["distance"][0], time_resolution=table.coords["event_time_offset"][1] - table.coords["event_time_offset"][0], - # TODO: Do we still want to store the chopper information in the lookup table? + # TODO: Do we still want to store the chopper info in the lookup table? ) +def _ltotal_range_from_ltotal(ltotal: sc.Variable) -> tuple[sc.Variable, sc.Variable]: + return (ltotal.min(), ltotal.max()) + + +def ltotal_range_from_ltotal_detector( + ltotal: DetectorLtotal[RunType], +) -> LtotalRange[RunType, snx.NXdetector]: + """ + Compute the range of total flight path lengths from the source to the detector from + the ltotal variable in the input data for the detector workflow. + """ + return LtotalRange[RunType, snx.NXdetector](_ltotal_range_from_ltotal(ltotal)) + + +def ltotal_range_from_ltotal_monitor( + ltotal: MonitorLtotal[RunType, MonitorType], +) -> LtotalRange[RunType, MonitorType]: + """ + Compute the range of total flight path lengths from the source to the detector from + the ltotal variable in the input data for the monitor workflow. + """ + return LtotalRange[RunType, MonitorType](_ltotal_range_from_ltotal(ltotal)) + + +def make_wavelength_lut_from_polygons_detector( + ltotal_range: LtotalRange[RunType, snx.NXdetector], + distance_resolution: DistanceResolution, + time_resolution: TimeResolution, + pulse_period: PulsePeriod, + pulse_stride: PulseStride, + frames: ChopperFrameSequence, +) -> LookupTable[RunType, snx.NXdetector]: + """ + Wrapper around _make_wavelength_lut_from_polygons to specify the Component as + snx.NXdetector, for use in the detector workflow. + """ + return LookupTable[RunType, snx.NXdetector]( + _make_wavelength_lut_from_polygons( + ltotal_range=ltotal_range, + distance_resolution=distance_resolution, + time_resolution=time_resolution, + pulse_period=pulse_period, + pulse_stride=pulse_stride, + frames=frames, + ) + ) + + +def make_wavelength_lut_from_polygons_monitor( + ltotal_range: LtotalRange[RunType, MonitorType], + distance_resolution: DistanceResolution, + time_resolution: TimeResolution, + pulse_period: PulsePeriod, + pulse_stride: PulseStride, + frames: ChopperFrameSequence, +) -> LookupTable[RunType, MonitorType]: + """ + Wrapper around _make_wavelength_lut_from_polygons to specify the Component as + snx.NXmonitor, for use in the monitor workflow. + """ + return LookupTable[RunType, MonitorType]( + _make_wavelength_lut_from_polygons( + ltotal_range=ltotal_range, + distance_resolution=distance_resolution, + time_resolution=time_resolution, + pulse_period=pulse_period, + pulse_stride=pulse_stride, + frames=frames, + ) + ) + + +def providers() -> tuple[Callable]: + """ + Return the providers for creating the wavelength lookup table. We only include the + provider for computing the lookup table from the polygons, as using the simulation + to compute the lookup table is expensive and not something we want to do by + default. + """ + return ( + ltotal_range_from_ltotal_detector, + ltotal_range_from_ltotal_monitor, + make_wavelength_lut_from_polygons_detector, + make_wavelength_lut_from_polygons_monitor, + compute_frame_sequence, + ) + + +def default_parameters() -> dict: + return { + PulsePeriod: 1.0 / sc.scalar(14.0, unit="Hz"), + PulseStride: 1, + DistanceResolution: sc.scalar(0.1, unit="m"), + TimeResolution: sc.scalar(50.0, unit='us'), + SourceBounds: SourceBounds( + time=(sc.scalar(0.0, unit='ms'), sc.scalar(5.0, unit='ms')), + wavelength=( + sc.scalar(0.0, unit='angstrom'), + sc.scalar(15.0, unit='angstrom'), + ), + ), + } + + def LookupTableWorkflow(use_simulation: bool = True): """ Create a workflow for computing a wavelength lookup table. @@ -760,15 +949,15 @@ def LookupTableWorkflow(use_simulation: bool = True): through a chopper cascade using the ``tof`` package, or from the acceptance diagram polygons generated by the ``chopper_cascade`` module. """ - default_params = { - PulsePeriod: 1.0 / sc.scalar(14.0, unit="Hz"), - PulseStride: 1, - DistanceResolution: sc.scalar(0.1, unit="m"), - TimeResolution: sc.scalar(250.0, unit='us'), - } - + default_params = default_parameters() if use_simulation: - providers = (make_wavelength_lookup_table, simulate_chopper_cascade_using_tof) + provs = ( + ltotal_range_from_ltotal_detector, + ltotal_range_from_ltotal_monitor, + make_wavelength_lut_from_simulation_detector, + make_wavelength_lut_from_simulation_monitor, + simulate_chopper_cascade_using_tof, + ) default_params.update( { NumberOfSimulatedNeutrons: 1_000_000, @@ -777,17 +966,6 @@ def LookupTableWorkflow(use_simulation: bool = True): } ) else: - providers = (make_wavelength_lut_from_polygons, compute_frame_sequence) - default_params.update( - { - SourceBounds: SourceBounds( - time=(sc.scalar(0.0, unit='ms'), sc.scalar(5.0, unit='ms')), - wavelength=( - sc.scalar(0.0, unit='angstrom'), - sc.scalar(15.0, unit='angstrom'), - ), - ) - } - ) + provs = providers() - return sl.Pipeline(providers, params=default_params) + return sl.Pipeline(provs, params=default_params) diff --git a/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py b/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py index e3a40e777..bf56f7307 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py +++ b/packages/essreduce/src/ess/reduce/unwrap/to_wavelength.py @@ -40,6 +40,7 @@ ErrorLimitedLookupTable, LookupTable, LookupTableRelativeErrorThreshold, + Lut, MonitorLtotal, PulseStrideOffset, WavelengthDetector, @@ -137,7 +138,7 @@ def __call__( def _compute_wavelength_histogram( - da: sc.DataArray, lookup: ErrorLimitedLookupTable, ltotal: sc.Variable + da: sc.DataArray, lookup: Lut, ltotal: sc.Variable ) -> sc.DataArray: # In NeXus, 'time_of_flight' is the canonical name in NXmonitor, but in some files, # it may be called 'tof' or 'frame_time'. @@ -243,7 +244,7 @@ def _guess_pulse_stride_offset( def _prepare_wavelength_interpolation_inputs( da: sc.DataArray, - lookup: ErrorLimitedLookupTable, + lookup: Lut, ltotal: sc.Variable, pulse_stride_offset: int | None, ) -> dict: @@ -336,7 +337,7 @@ def _prepare_wavelength_interpolation_inputs( def _compute_wavelength_events( da: sc.DataArray, - lookup: ErrorLimitedLookupTable, + lookup: Lut, ltotal: sc.Variable, pulse_stride_offset: int | None, ) -> sc.DataArray: @@ -435,9 +436,7 @@ def monitor_ltotal_from_straight_line_approximation( ) -def _mask_large_uncertainty_in_lut( - table: LookupTable, error_threshold: float -) -> LookupTable: +def _mask_large_uncertainty_in_lut(table: Lut, error_threshold: float) -> Lut: """ Mask regions in the lookup table with large uncertainty using NaNs. @@ -452,7 +451,7 @@ def _mask_large_uncertainty_in_lut( da = table.array relative_error = sc.stddevs(da.data) / sc.values(da.data) mask = relative_error > sc.scalar(error_threshold) - return LookupTable( + return Lut( **{ **asdict(table), "array": sc.where(mask, sc.scalar(np.nan, unit=da.unit), da), @@ -461,10 +460,10 @@ def _mask_large_uncertainty_in_lut( def mask_large_uncertainty_in_lut_detector( - table: LookupTable, + table: LookupTable[RunType, snx.NXdetector], error_threshold: LookupTableRelativeErrorThreshold, detector_name: NeXusDetectorName, -) -> ErrorLimitedLookupTable[snx.NXdetector]: +) -> ErrorLimitedLookupTable[RunType, snx.NXdetector]: """ Mask regions in the wavelength lookup table with large uncertainty using NaNs. @@ -479,7 +478,7 @@ def mask_large_uncertainty_in_lut_detector( Name of the detector for which to apply the error threshold. This is used to get the correct error threshold from the dictionary of error thresholds. """ - return ErrorLimitedLookupTable[snx.NXdetector]( + return ErrorLimitedLookupTable[RunType, snx.NXdetector]( _mask_large_uncertainty_in_lut( table=table, error_threshold=error_threshold[detector_name] ) @@ -487,10 +486,10 @@ def mask_large_uncertainty_in_lut_detector( def mask_large_uncertainty_in_lut_monitor( - table: LookupTable, + table: LookupTable[RunType, MonitorType], error_threshold: LookupTableRelativeErrorThreshold, monitor_name: NeXusName[MonitorType], -) -> ErrorLimitedLookupTable[MonitorType]: +) -> ErrorLimitedLookupTable[RunType, MonitorType]: """ Mask regions in the wavelength lookup table with large uncertainty using NaNs. @@ -505,7 +504,7 @@ def mask_large_uncertainty_in_lut_monitor( Name of the monitor for which to apply the error threshold. This is used to get the correct error threshold from the dictionary of error thresholds. """ - return ErrorLimitedLookupTable[MonitorType]( + return ErrorLimitedLookupTable[RunType, MonitorType]( _mask_large_uncertainty_in_lut( table=table, error_threshold=error_threshold[monitor_name] ) @@ -514,7 +513,7 @@ def mask_large_uncertainty_in_lut_monitor( def _compute_wavelength_data( da: sc.DataArray, - lookup: ErrorLimitedLookupTable[Component], + lookup: ErrorLimitedLookupTable[RunType, Component], ltotal: sc.Variable, pulse_stride_offset: int, ) -> sc.DataArray: @@ -533,7 +532,7 @@ def _compute_wavelength_data( def detector_wavelength_data( detector_data: RawDetector[RunType], - lookup: ErrorLimitedLookupTable[snx.NXdetector], + lookup: ErrorLimitedLookupTable[RunType, snx.NXdetector], ltotal: DetectorLtotal[RunType], pulse_stride_offset: PulseStrideOffset, ) -> WavelengthDetector[RunType]: @@ -568,7 +567,7 @@ def detector_wavelength_data( def monitor_wavelength_data( monitor_data: RawMonitor[RunType, MonitorType], - lookup: ErrorLimitedLookupTable[MonitorType], + lookup: ErrorLimitedLookupTable[RunType, MonitorType], ltotal: MonitorLtotal[RunType, MonitorType], pulse_stride_offset: PulseStrideOffset, ) -> WavelengthMonitor[RunType, MonitorType]: diff --git a/packages/essreduce/src/ess/reduce/unwrap/types.py b/packages/essreduce/src/ess/reduce/unwrap/types.py index 86dd086c9..354f0c7eb 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/types.py +++ b/packages/essreduce/src/ess/reduce/unwrap/types.py @@ -15,9 +15,10 @@ @dataclass -class LookupTable: +class Lut: """ - Lookup table giving wavelength as a function of distance and ``event_time_offset``. + Base class for a lookup table giving wavelength as a function of distance and + ``event_time_offset``. """ array: sc.DataArray @@ -44,7 +45,12 @@ def plot(self, *args, **kwargs) -> Any: return self.array.plot(*args, **kwargs) -class ErrorLimitedLookupTable(sl.Scope[Component, LookupTable], LookupTable): +class LookupTable(sl.Scope[RunType, Component, Lut], Lut): + """Lookup table giving wavelength as a function of distance and + ``event_time_offset`` for each beamline component (detector, monitor).""" + + +class ErrorLimitedLookupTable(sl.Scope[RunType, Component, Lut], Lut): """Lookup table that is masked with NaNs in regions where the standard deviation of the wavelength is above a certain threshold.""" diff --git a/packages/essreduce/src/ess/reduce/unwrap/workflow.py b/packages/essreduce/src/ess/reduce/unwrap/workflow.py index b493162f5..9da6b710c 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/workflow.py +++ b/packages/essreduce/src/ess/reduce/unwrap/workflow.py @@ -6,7 +6,7 @@ import scipp as sc from ..nexus import GenericNeXusWorkflow -from . import to_wavelength +from . import lut, to_wavelength from .types import LookupTable, LookupTableFilename, PulseStrideOffset @@ -82,12 +82,14 @@ def GenericUnwrapWorkflow( """ wf = GenericNeXusWorkflow(run_types=run_types, monitor_types=monitor_types) - for provider in to_wavelength.providers(): + for provider in (*to_wavelength.providers(), *lut.providers()): wf.insert(provider) - wf.insert(load_lookup_table) + # wf.insert(load_lookup_table) # Default parameters wf[PulseStrideOffset] = None + for key, value in lut.default_parameters().items(): + wf[key] = value return wf