1111from easyreflectometry .data import DataSet1D
1212from easyreflectometry .model import Model
1313
14+ _VALID_OBJECTIVES = ('legacy_mask' , 'mighell' , 'hybrid' , 'auto' )
15+ _EPS = 1e-30
16+
17+
18+ def _validate_objective (objective : str ) -> str :
19+ """Validate and resolve the objective string.
20+
21+ :param objective: The objective mode string.
22+ :type objective: str
23+ :return: Resolved objective string ('auto' becomes 'hybrid').
24+ :rtype: str
25+ :raises ValueError: If the objective is not one of the valid options.
26+ """
27+ if objective not in _VALID_OBJECTIVES :
28+ raise ValueError (f'Unknown objective { objective !r} . Valid options: { _VALID_OBJECTIVES } ' )
29+ if objective == 'auto' :
30+ return 'hybrid'
31+ return objective
32+
33+
34+ def _prepare_fit_arrays (
35+ x_vals : np .ndarray ,
36+ y_vals : np .ndarray ,
37+ variances : np .ndarray ,
38+ objective : str ,
39+ ) -> tuple [np .ndarray , np .ndarray , np .ndarray , dict ]:
40+ """Prepare x, y_eff, and weights arrays for fitting based on the objective mode.
41+
42+ For ``legacy_mask``, zero-variance points are removed from all arrays.
43+ For ``hybrid``, valid-variance points use standard WLS while zero-variance
44+ points use Mighell-transformed y and weights.
45+ For ``mighell``, all points use the Mighell transform.
46+
47+ Note: ``variances`` here means σ² (the scipp convention), not σ.
48+
49+ :param x_vals: Independent variable values.
50+ :type x_vals: np.ndarray
51+ :param y_vals: Observed dependent variable values.
52+ :type y_vals: np.ndarray
53+ :param variances: Variance (σ²) of each observed point.
54+ :type variances: np.ndarray
55+ :param objective: One of 'legacy_mask', 'hybrid', 'mighell'.
56+ :type objective: str
57+ :return: Tuple of (x_out, y_eff, weights, stats) where stats is a dict
58+ with keys 'valid', 'mighell_substituted', 'masked'.
59+ :rtype: tuple[np.ndarray, np.ndarray, np.ndarray, dict]
60+ """
61+ n = len (y_vals )
62+ zero_mask = variances <= 0.0
63+ n_zero = int (np .sum (zero_mask ))
64+ n_valid = n - n_zero
65+
66+ if objective == 'legacy_mask' :
67+ valid = ~ zero_mask
68+ x_out = x_vals [valid ]
69+ y_eff = y_vals [valid ]
70+ if n_valid > 0 :
71+ weights = 1.0 / np .sqrt (variances [valid ])
72+ else :
73+ weights = np .array ([])
74+ stats = {'valid' : n_valid , 'mighell_substituted' : 0 , 'masked' : n_zero }
75+ return x_out , y_eff , weights , stats
76+
77+ # hybrid or mighell
78+ y_eff = np .copy (y_vals )
79+ sigma = np .empty (n )
80+
81+ if objective == 'mighell' :
82+ apply_mighell = np .ones (n , dtype = bool )
83+ else :
84+ # hybrid: apply Mighell only to zero-variance points
85+ apply_mighell = zero_mask
86+
87+ # Standard WLS for non-Mighell points
88+ standard = ~ apply_mighell
89+ if np .any (standard ):
90+ sigma [standard ] = np .sqrt (variances [standard ])
91+
92+ # Mighell transform for selected points
93+ if np .any (apply_mighell ):
94+ y_m = y_vals [apply_mighell ]
95+ delta = np .minimum (y_m , 1.0 )
96+ y_eff [apply_mighell ] = y_m + delta
97+ sigma [apply_mighell ] = np .sqrt (np .maximum (y_m + 1.0 , _EPS ))
98+
99+ weights = 1.0 / sigma
100+ n_mighell = int (np .sum (apply_mighell ))
101+ stats = {'valid' : n - n_mighell , 'mighell_substituted' : n_mighell , 'masked' : 0 }
102+ return x_vals , y_eff , weights , stats
103+
14104
15105class MultiFitter :
16- def __init__ (self , * args : Model ):
17- r"""A convinence class for the :py:class:`easyscience.Fitting.Fitting`
106+ def __init__ (self , * args : Model , objective : str = 'hybrid' ):
107+ r"""A convenience class for the :py:class:`easyscience.Fitting.Fitting`
18108 which will populate the :py:class:`sc.DataGroup` appropriately
19109 after the fitting is performed.
20110
21- :param args: Reflectometry model
111+ :param args: Reflectometry model(s).
112+ :param objective: Zero-variance handling strategy. One of
113+ ``'hybrid'`` (default, Mighell for zero-variance, WLS otherwise),
114+ ``'mighell'`` (Mighell transform for all points),
115+ ``'legacy_mask'`` (drop zero-variance points),
116+ ``'auto'`` (alias for ``'hybrid'``).
117+ :type objective: str
22118 """
23119
24120 # This lets the unique_name be passed with the fit_func.
@@ -32,18 +128,29 @@ def wrapped(*args, **kwargs):
32128 self ._models = args
33129 self .easy_science_multi_fitter = EasyScienceMultiFitter (args , self ._fit_func )
34130 self ._fit_results : list [FitResults ] | None = None
35-
36- def fit (self , data : sc .DataGroup , id : int = 0 ) -> sc .DataGroup :
131+ self ._objective = _validate_objective (objective )
132+
133+ def fit (self , data : sc .DataGroup , id : int = 0 , objective : str | None = None ) -> sc .DataGroup :
134+ """Perform the fitting and populate the DataGroups with the result.
135+
136+ :param data: DataGroup to be fitted to and populated.
137+ :type data: sc.DataGroup
138+ :param id: Unused parameter kept for backward compatibility.
139+ :type id: int
140+ :param objective: Per-call override for the zero-variance objective.
141+ If ``None``, uses the instance default set at construction.
142+ :type objective: str or None
143+ :return: A new DataGroup with fitted model curves, SLD profiles, and fit statistics.
144+ :rtype: sc.DataGroup
145+
146+ :note: Under the ``mighell`` objective all points are transformed,
147+ so ``reduced_chi`` is not a classical chi-square statistic.
148+ Under ``hybrid``, only zero-variance points are transformed;
149+ when they are a small fraction of the data the chi-square
150+ remains approximately classical.
37151 """
38- Perform the fitting and populate the DataGroups with the result.
152+ obj = _validate_objective ( objective ) if objective is not None else self . _objective
39153
40- :param data: DataGroup to be fitted to and populated
41- :param method: Optimisation method
42-
43- :note: Points with zero variance in the data will be automatically masked
44- out during fitting. A warning will be issued if any such points
45- are found, indicating the number of points masked per reflectivity.
46- """
47154 refl_nums = [k [3 :] for k in data ['coords' ].keys () if 'Qz' == k [:2 ]]
48155 x = []
49156 y = []
@@ -55,25 +162,24 @@ def fit(self, data: sc.DataGroup, id: int = 0) -> sc.DataGroup:
55162 y_vals = data ['data' ][f'R_{ i } ' ].values
56163 variances = data ['data' ][f'R_{ i } ' ].variances
57164
58- # Find points with non-zero variance
59- zero_variance_mask = variances == 0.0
60- num_zero_variance = np .sum (zero_variance_mask )
165+ x_out , y_eff , weights , stats = _prepare_fit_arrays (x_vals , y_vals , variances , obj )
61166
62- if num_zero_variance > 0 :
167+ if stats [ 'masked' ] > 0 :
63168 warnings .warn (
64- f'Masked { num_zero_variance } data point(s) in reflectivity { i } due to zero variance during fitting.' ,
169+ f'Masked { stats ["masked" ]} data point(s) in reflectivity { i } '
170+ 'due to zero variance during fitting.' ,
171+ UserWarning ,
172+ )
173+ if stats ['mighell_substituted' ] > 0 :
174+ warnings .warn (
175+ f'Applied Mighell substitution to { stats ["mighell_substituted" ]} '
176+ f'zero-variance point(s) in reflectivity { i } during fitting.' ,
65177 UserWarning ,
66178 )
67179
68- # Keep only points with non-zero variances
69- valid_mask = ~ zero_variance_mask
70- x_vals_masked = x_vals [valid_mask ]
71- y_vals_masked = y_vals [valid_mask ]
72- variances_masked = variances [valid_mask ]
73-
74- x .append (x_vals_masked )
75- y .append (y_vals_masked )
76- dy .append (1 / np .sqrt (variances_masked ))
180+ x .append (x_out )
181+ y .append (y_eff )
182+ dy .append (weights )
77183
78184 result = self .easy_science_multi_fitter .fit (x , y , weights = dy )
79185 self ._fit_results = result
@@ -94,36 +200,43 @@ def fit(self, data: sc.DataGroup, id: int = 0) -> sc.DataGroup:
94200 new_data ['success' ] = result [i ].success
95201 return new_data
96202
97- def fit_single_data_set_1d (self , data : DataSet1D ) -> FitResults :
203+ def fit_single_data_set_1d (self , data : DataSet1D , objective : str | None = None ) -> FitResults :
204+ """Perform fitting on a single 1D dataset.
205+
206+ :param data: The 1D dataset to fit. Note that ``data.ye`` stores
207+ variances (σ²), not standard deviations.
208+ :type data: DataSet1D
209+ :param objective: Per-call override for the zero-variance objective.
210+ If ``None``, uses the instance default set at construction.
211+ :type objective: str or None
212+ :return: Fit results from the minimizer.
213+ :rtype: FitResults
98214 """
99- Perform the fitting and populate the DataGroups with the result.
215+ obj = _validate_objective ( objective ) if objective is not None else self . _objective
100216
101- :param data: DataGroup to be fitted to and populated
102- :param method: Optimisation method
103- """
104217 x_vals = np .asarray (data .x )
105218 y_vals = np .asarray (data .y )
106219 variances = np .asarray (data .ye )
107220
108- zero_variance_mask = variances == 0.0
109- num_zero_variance = int (np .sum (zero_variance_mask ))
221+ x_out , y_eff , weights , stats = _prepare_fit_arrays (x_vals , y_vals , variances , obj )
110222
111- if num_zero_variance > 0 :
223+ if stats [ 'masked' ] > 0 :
112224 warnings .warn (
113- f'Masked { num_zero_variance } data point(s) in single-dataset fit due to zero variance during fitting.' ,
225+ f'Masked { stats ["masked" ]} data point(s) in single-dataset fit '
226+ 'due to zero variance during fitting.' ,
227+ UserWarning ,
228+ )
229+ if stats ['mighell_substituted' ] > 0 :
230+ warnings .warn (
231+ f'Applied Mighell substitution to { stats ["mighell_substituted" ]} '
232+ 'zero-variance point(s) in single-dataset fit during fitting.' ,
114233 UserWarning ,
115234 )
116235
117- valid_mask = ~ zero_variance_mask
118- if not np .any (valid_mask ):
236+ if obj == 'legacy_mask' and len (x_out ) == 0 :
119237 raise ValueError ('Cannot fit single dataset: all points have zero variance.' )
120238
121- x_vals_masked = x_vals [valid_mask ]
122- y_vals_masked = y_vals [valid_mask ]
123- variances_masked = variances [valid_mask ]
124-
125- weights = 1.0 / np .sqrt (variances_masked )
126- result = self .easy_science_multi_fitter .fit (x = [x_vals_masked ], y = [y_vals_masked ], weights = [weights ])[0 ]
239+ result = self .easy_science_multi_fitter .fit (x = [x_out ], y = [y_eff ], weights = [weights ])[0 ]
127240 self ._fit_results = [result ]
128241 return result
129242
0 commit comments