Skip to content

Commit 717ef1e

Browse files
committed
Initial implementation of the Mighell algorithm for data with zero variances
1 parent 41c1c5f commit 717ef1e

4 files changed

Lines changed: 1354 additions & 67 deletions

File tree

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
# Version 1.4.0 (unreleased)
2+
3+
Add Mighell-based handling of zero-variance points in fitting (issue #256).
4+
Zero-variance data points are no longer forcibly discarded; instead, a hybrid
5+
objective applies a Mighell substitution for zero-variance points while using
6+
standard weighted least squares for the rest. The previous masking behavior is
7+
available via `objective='legacy_mask'`. New `objective` parameter on
8+
`MultiFitter`, `fit()`, and `fit_single_data_set_1d()`.
9+
110
# Version 1.3.3 (17 June 2025)
211

312
Added Chi^2 and fit status to fitting results.

notebooks/zero_variance_fitting.ipynb

Lines changed: 828 additions & 0 deletions
Large diffs are not rendered by default.

src/easyreflectometry/fitting.py

Lines changed: 157 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,110 @@
1111
from easyreflectometry.data import DataSet1D
1212
from easyreflectometry.model import Model
1313

14+
_VALID_OBJECTIVES = ('legacy_mask', 'mighell', 'hybrid', 'auto')
15+
_EPS = 1e-30
16+
17+
18+
def _validate_objective(objective: str) -> str:
19+
"""Validate and resolve the objective string.
20+
21+
:param objective: The objective mode string.
22+
:type objective: str
23+
:return: Resolved objective string ('auto' becomes 'hybrid').
24+
:rtype: str
25+
:raises ValueError: If the objective is not one of the valid options.
26+
"""
27+
if objective not in _VALID_OBJECTIVES:
28+
raise ValueError(f'Unknown objective {objective!r}. Valid options: {_VALID_OBJECTIVES}')
29+
if objective == 'auto':
30+
return 'hybrid'
31+
return objective
32+
33+
34+
def _prepare_fit_arrays(
35+
x_vals: np.ndarray,
36+
y_vals: np.ndarray,
37+
variances: np.ndarray,
38+
objective: str,
39+
) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
40+
"""Prepare x, y_eff, and weights arrays for fitting based on the objective mode.
41+
42+
For ``legacy_mask``, zero-variance points are removed from all arrays.
43+
For ``hybrid``, valid-variance points use standard WLS while zero-variance
44+
points use Mighell-transformed y and weights.
45+
For ``mighell``, all points use the Mighell transform.
46+
47+
Note: ``variances`` here means σ² (the scipp convention), not σ.
48+
49+
:param x_vals: Independent variable values.
50+
:type x_vals: np.ndarray
51+
:param y_vals: Observed dependent variable values.
52+
:type y_vals: np.ndarray
53+
:param variances: Variance (σ²) of each observed point.
54+
:type variances: np.ndarray
55+
:param objective: One of 'legacy_mask', 'hybrid', 'mighell'.
56+
:type objective: str
57+
:return: Tuple of (x_out, y_eff, weights, stats) where stats is a dict
58+
with keys 'valid', 'mighell_substituted', 'masked'.
59+
:rtype: tuple[np.ndarray, np.ndarray, np.ndarray, dict]
60+
"""
61+
n = len(y_vals)
62+
zero_mask = variances <= 0.0
63+
n_zero = int(np.sum(zero_mask))
64+
n_valid = n - n_zero
65+
66+
if objective == 'legacy_mask':
67+
valid = ~zero_mask
68+
x_out = x_vals[valid]
69+
y_eff = y_vals[valid]
70+
if n_valid > 0:
71+
weights = 1.0 / np.sqrt(variances[valid])
72+
else:
73+
weights = np.array([])
74+
stats = {'valid': n_valid, 'mighell_substituted': 0, 'masked': n_zero}
75+
return x_out, y_eff, weights, stats
76+
77+
# hybrid or mighell
78+
y_eff = np.copy(y_vals)
79+
sigma = np.empty(n)
80+
81+
if objective == 'mighell':
82+
apply_mighell = np.ones(n, dtype=bool)
83+
else:
84+
# hybrid: apply Mighell only to zero-variance points
85+
apply_mighell = zero_mask
86+
87+
# Standard WLS for non-Mighell points
88+
standard = ~apply_mighell
89+
if np.any(standard):
90+
sigma[standard] = np.sqrt(variances[standard])
91+
92+
# Mighell transform for selected points
93+
if np.any(apply_mighell):
94+
y_m = y_vals[apply_mighell]
95+
delta = np.minimum(y_m, 1.0)
96+
y_eff[apply_mighell] = y_m + delta
97+
sigma[apply_mighell] = np.sqrt(np.maximum(y_m + 1.0, _EPS))
98+
99+
weights = 1.0 / sigma
100+
n_mighell = int(np.sum(apply_mighell))
101+
stats = {'valid': n - n_mighell, 'mighell_substituted': n_mighell, 'masked': 0}
102+
return x_vals, y_eff, weights, stats
103+
14104

15105
class MultiFitter:
16-
def __init__(self, *args: Model):
17-
r"""A convinence class for the :py:class:`easyscience.Fitting.Fitting`
106+
def __init__(self, *args: Model, objective: str = 'hybrid'):
107+
r"""A convenience class for the :py:class:`easyscience.Fitting.Fitting`
18108
which will populate the :py:class:`sc.DataGroup` appropriately
19109
after the fitting is performed.
20110
21-
:param args: Reflectometry model
111+
:param args: Reflectometry model(s).
112+
:param objective: Zero-variance handling strategy. One of
113+
``'hybrid'`` (default, Mighell for zero-variance, WLS otherwise),
114+
``'mighell'`` (Mighell transform for all points),
115+
``'legacy_mask'`` (drop zero-variance points),
116+
``'auto'`` (alias for ``'hybrid'``).
117+
:type objective: str
22118
"""
23119

24120
# This lets the unique_name be passed with the fit_func.
@@ -32,18 +128,29 @@ def wrapped(*args, **kwargs):
32128
self._models = args
33129
self.easy_science_multi_fitter = EasyScienceMultiFitter(args, self._fit_func)
34130
self._fit_results: list[FitResults] | None = None
35-
36-
def fit(self, data: sc.DataGroup, id: int = 0) -> sc.DataGroup:
131+
self._objective = _validate_objective(objective)
132+
133+
def fit(self, data: sc.DataGroup, id: int = 0, objective: str | None = None) -> sc.DataGroup:
134+
"""Perform the fitting and populate the DataGroups with the result.
135+
136+
:param data: DataGroup to be fitted to and populated.
137+
:type data: sc.DataGroup
138+
:param id: Unused parameter kept for backward compatibility.
139+
:type id: int
140+
:param objective: Per-call override for the zero-variance objective.
141+
If ``None``, uses the instance default set at construction.
142+
:type objective: str or None
143+
:return: A new DataGroup with fitted model curves, SLD profiles, and fit statistics.
144+
:rtype: sc.DataGroup
145+
146+
:note: Under the ``mighell`` objective all points are transformed,
147+
so ``reduced_chi`` is not a classical chi-square statistic.
148+
Under ``hybrid``, only zero-variance points are transformed;
149+
when they are a small fraction of the data the chi-square
150+
remains approximately classical.
37151
"""
38-
Perform the fitting and populate the DataGroups with the result.
152+
obj = _validate_objective(objective) if objective is not None else self._objective
39153

40-
:param data: DataGroup to be fitted to and populated
41-
:param method: Optimisation method
42-
43-
:note: Points with zero variance in the data will be automatically masked
44-
out during fitting. A warning will be issued if any such points
45-
are found, indicating the number of points masked per reflectivity.
46-
"""
47154
refl_nums = [k[3:] for k in data['coords'].keys() if 'Qz' == k[:2]]
48155
x = []
49156
y = []
@@ -55,25 +162,24 @@ def fit(self, data: sc.DataGroup, id: int = 0) -> sc.DataGroup:
55162
y_vals = data['data'][f'R_{i}'].values
56163
variances = data['data'][f'R_{i}'].variances
57164

58-
# Find points with non-zero variance
59-
zero_variance_mask = variances == 0.0
60-
num_zero_variance = np.sum(zero_variance_mask)
165+
x_out, y_eff, weights, stats = _prepare_fit_arrays(x_vals, y_vals, variances, obj)
61166

62-
if num_zero_variance > 0:
167+
if stats['masked'] > 0:
63168
warnings.warn(
64-
f'Masked {num_zero_variance} data point(s) in reflectivity {i} due to zero variance during fitting.',
169+
f'Masked {stats["masked"]} data point(s) in reflectivity {i} '
170+
'due to zero variance during fitting.',
171+
UserWarning,
172+
)
173+
if stats['mighell_substituted'] > 0:
174+
warnings.warn(
175+
f'Applied Mighell substitution to {stats["mighell_substituted"]} '
176+
f'zero-variance point(s) in reflectivity {i} during fitting.',
65177
UserWarning,
66178
)
67179

68-
# Keep only points with non-zero variances
69-
valid_mask = ~zero_variance_mask
70-
x_vals_masked = x_vals[valid_mask]
71-
y_vals_masked = y_vals[valid_mask]
72-
variances_masked = variances[valid_mask]
73-
74-
x.append(x_vals_masked)
75-
y.append(y_vals_masked)
76-
dy.append(1 / np.sqrt(variances_masked))
180+
x.append(x_out)
181+
y.append(y_eff)
182+
dy.append(weights)
77183

78184
result = self.easy_science_multi_fitter.fit(x, y, weights=dy)
79185
self._fit_results = result
@@ -94,36 +200,43 @@ def fit(self, data: sc.DataGroup, id: int = 0) -> sc.DataGroup:
94200
new_data['success'] = result[i].success
95201
return new_data
96202

97-
def fit_single_data_set_1d(self, data: DataSet1D) -> FitResults:
203+
def fit_single_data_set_1d(self, data: DataSet1D, objective: str | None = None) -> FitResults:
204+
"""Perform fitting on a single 1D dataset.
205+
206+
:param data: The 1D dataset to fit. Note that ``data.ye`` stores
207+
variances (σ²), not standard deviations.
208+
:type data: DataSet1D
209+
:param objective: Per-call override for the zero-variance objective.
210+
If ``None``, uses the instance default set at construction.
211+
:type objective: str or None
212+
:return: Fit results from the minimizer.
213+
:rtype: FitResults
98214
"""
99-
Perform the fitting and populate the DataGroups with the result.
215+
obj = _validate_objective(objective) if objective is not None else self._objective
100216

101-
:param data: DataGroup to be fitted to and populated
102-
:param method: Optimisation method
103-
"""
104217
x_vals = np.asarray(data.x)
105218
y_vals = np.asarray(data.y)
106219
variances = np.asarray(data.ye)
107220

108-
zero_variance_mask = variances == 0.0
109-
num_zero_variance = int(np.sum(zero_variance_mask))
221+
x_out, y_eff, weights, stats = _prepare_fit_arrays(x_vals, y_vals, variances, obj)
110222

111-
if num_zero_variance > 0:
223+
if stats['masked'] > 0:
112224
warnings.warn(
113-
f'Masked {num_zero_variance} data point(s) in single-dataset fit due to zero variance during fitting.',
225+
f'Masked {stats["masked"]} data point(s) in single-dataset fit '
226+
'due to zero variance during fitting.',
227+
UserWarning,
228+
)
229+
if stats['mighell_substituted'] > 0:
230+
warnings.warn(
231+
f'Applied Mighell substitution to {stats["mighell_substituted"]} '
232+
'zero-variance point(s) in single-dataset fit during fitting.',
114233
UserWarning,
115234
)
116235

117-
valid_mask = ~zero_variance_mask
118-
if not np.any(valid_mask):
236+
if obj == 'legacy_mask' and len(x_out) == 0:
119237
raise ValueError('Cannot fit single dataset: all points have zero variance.')
120238

121-
x_vals_masked = x_vals[valid_mask]
122-
y_vals_masked = y_vals[valid_mask]
123-
variances_masked = variances[valid_mask]
124-
125-
weights = 1.0 / np.sqrt(variances_masked)
126-
result = self.easy_science_multi_fitter.fit(x=[x_vals_masked], y=[y_vals_masked], weights=[weights])[0]
239+
result = self.easy_science_multi_fitter.fit(x=[x_out], y=[y_eff], weights=[weights])[0]
127240
self._fit_results = [result]
128241
return result
129242

0 commit comments

Comments
 (0)