Skip to content

Commit b65486c

Browse files
authored
Improved sample handling (#293)
* fix for appending samples from loaded models * added unit tests * use most recent version of refl1d * update docs to include add_sample_from_orso * temporarily revert to fwhm * minor fixes after CR + ruff format * back to pointwise * proper conversion from stack to layers. * added method * temporarily switch off pointwise * fixed SLD read. * added handling for orso names (model and experiment) * minor fix for names * CR review fixes * model reindexing issue fixed. Project rst added * path for SLD-less orso files * fix variances sent to bumps. Make FWHM default (temporarily) * fixed test to correspond to new resolution * added logs for changing minimizers * reload constraints where necessary * expose chi2 properly * PR review #1 * Fix the zero variances issue * allow 3.13 * added some tests
1 parent 70066df commit b65486c

11 files changed

Lines changed: 258 additions & 48 deletions

File tree

.copier-answers.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
_commit: 8bdcedc
33
_src_path: gh:/EasyScience/EasyProjectTemplate
44
description: A reflectometry python package built on the EasyScience framework.
5-
max_python: '3.12'
5+
max_python: '3.13'
66
min_python: '3.9'
77
orgname: EasyScience
88
packagename: easyreflectometry

.github/workflows/python-ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
strategy:
2929
max-parallel: 4
3030
matrix:
31-
python-version: ['3.11', '3.12']
31+
python-version: ['3.11', '3.12', '3.13']
3232
os: [ubuntu-latest, macos-latest, windows-2022]
3333

3434
runs-on: ${{ matrix.os }}

.github/workflows/python-package.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
runs-on: ubuntu-latest
1919
strategy:
2020
matrix:
21-
python-version: ['3.11','3.12']
21+
python-version: ['3.11','3.12','3.13']
2222
if: "!contains(github.event.head_commit.message, '[ci skip]')"
2323

2424
steps:

CONTRIBUTING.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ Before you submit a pull request, check that it meets these guidelines:
102102
2. If the pull request adds functionality, the docs should be updated. Put
103103
your new functionality into a function with a docstring, and add the
104104
feature to the list in README.md.
105-
3. The pull request should work for Python, 3.11 and 3.12, and for PyPy. Check
105+
3. The pull request should work for Python, 3.11, 3.12, and 3.13, and for PyPy. Check
106106
https://travis-ci.com/easyScience/EasyReflectometryLib/pull_requests
107107
and make sure that the tests pass for all supported Python versions.
108108

pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@ classifiers = [
2323
"Programming Language :: Python :: 3 :: Only",
2424
"Programming Language :: Python :: 3.11",
2525
"Programming Language :: Python :: 3.12",
26+
"Programming Language :: Python :: 3.13",
2627
"Development Status :: 3 - Alpha"
2728
]
2829

29-
requires-python = ">=3.11,<3.13"
30+
requires-python = ">=3.11,<3.14"
3031

3132
dependencies = [
3233
"easyscience @ git+https://github.com/easyscience/corelib.git@develop",
@@ -134,11 +135,12 @@ force-single-line = true
134135
legacy_tox_ini = """
135136
[tox]
136137
isolated_build = True
137-
envlist = py{3.11,3.12}
138+
envlist = py{3.11,3.12,3.13}
138139
[gh-actions]
139140
python =
140141
3.11: py311
141142
3.12: py312
143+
3.13: py313
142144
[gh-actions:env]
143145
PLATFORM =
144146
ubuntu-latest: linux

src/easyreflectometry/fitting.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def wrapped(*args, **kwargs):
3131
self._fit_func = [func_wrapper(m.interface.fit_func, m.unique_name) for m in args]
3232
self._models = args
3333
self.easy_science_multi_fitter = EasyScienceMultiFitter(args, self._fit_func)
34+
self._fit_results: list[FitResults] | None = None
3435

3536
def fit(self, data: sc.DataGroup, id: int = 0) -> sc.DataGroup:
3637
"""
@@ -75,6 +76,7 @@ def fit(self, data: sc.DataGroup, id: int = 0) -> sc.DataGroup:
7576
dy.append(1 / np.sqrt(variances_masked))
7677

7778
result = self.easy_science_multi_fitter.fit(x, y, weights=dy)
79+
self._fit_results = result
7880
new_data = data.copy()
7981
for i, _ in enumerate(result):
8082
id = refl_nums[i]
@@ -99,7 +101,53 @@ def fit_single_data_set_1d(self, data: DataSet1D) -> FitResults:
99101
:param data: DataGroup to be fitted to and populated
100102
:param method: Optimisation method
101103
"""
102-
return self.easy_science_multi_fitter.fit(x=[data.x], y=[data.y], weights=[data.ye])[0]
104+
x_vals = np.asarray(data.x)
105+
y_vals = np.asarray(data.y)
106+
variances = np.asarray(data.ye)
107+
108+
zero_variance_mask = variances == 0.0
109+
num_zero_variance = int(np.sum(zero_variance_mask))
110+
111+
if num_zero_variance > 0:
112+
warnings.warn(
113+
f'Masked {num_zero_variance} data point(s) in single-dataset fit due to zero variance during fitting.',
114+
UserWarning,
115+
)
116+
117+
valid_mask = ~zero_variance_mask
118+
if not np.any(valid_mask):
119+
raise ValueError('Cannot fit single dataset: all points have zero variance.')
120+
121+
x_vals_masked = x_vals[valid_mask]
122+
y_vals_masked = y_vals[valid_mask]
123+
variances_masked = variances[valid_mask]
124+
125+
weights = 1.0 / np.sqrt(variances_masked)
126+
result = self.easy_science_multi_fitter.fit(x=[x_vals_masked], y=[y_vals_masked], weights=[weights])[0]
127+
self._fit_results = [result]
128+
return result
129+
130+
@property
131+
def chi2(self) -> float | None:
132+
"""Total chi-squared across all fitted datasets, or None if no fit has been performed."""
133+
if self._fit_results is None:
134+
return None
135+
return sum(r.chi2 for r in self._fit_results)
136+
137+
@property
138+
def reduced_chi(self) -> float | None:
139+
"""Reduced chi-squared from the most recent fit, or None if no fit has been performed."""
140+
if self._fit_results is None:
141+
return None
142+
total_chi2 = sum(r.chi2 for r in self._fit_results)
143+
total_points = sum(np.size(r.x) for r in self._fit_results)
144+
n_params = self._fit_results[0].n_pars
145+
total_dof = total_points - n_params
146+
147+
if total_dof <= 0:
148+
return None
149+
150+
return total_chi2 / total_dof
103151

104152
def switch_minimizer(self, minimizer: AvailableMinimizers) -> None:
105153
"""

src/easyreflectometry/project.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import datetime
22
import json
3+
import logging
34
import os
45
from pathlib import Path
56
from typing import Dict
@@ -10,8 +11,8 @@
1011
import numpy as np
1112
from easyscience import global_object
1213
from easyscience.fitting import AvailableMinimizers
13-
from easyscience.fitting.fitter import DEFAULT_MINIMIZER
1414
from easyscience.variable import Parameter
15+
from easyscience.variable.parameter_dependency_resolver import resolve_all_parameter_dependencies
1516
from scipp import DataGroup
1617

1718
from easyreflectometry.calculators import CalculatorFactory
@@ -20,23 +21,23 @@
2021
from easyreflectometry.data.measurement import extract_orso_title
2122
from easyreflectometry.data.measurement import load_data_from_orso_file
2223
from easyreflectometry.fitting import MultiFitter
23-
from easyreflectometry.model import LinearSpline
2424
from easyreflectometry.model import Model
2525
from easyreflectometry.model import ModelCollection
2626
from easyreflectometry.model import PercentageFwhm
27-
from easyreflectometry.model import Pointwise
2827
from easyreflectometry.sample import Layer
2928
from easyreflectometry.sample import Material
3029
from easyreflectometry.sample import MaterialCollection
3130
from easyreflectometry.sample import Multilayer
3231
from easyreflectometry.sample import Sample
3332
from easyreflectometry.sample.collections.base_collection import BaseCollection
3433

34+
logger = logging.getLogger(__name__)
35+
3536
Q_MIN = 0.001
3637
Q_MAX = 0.3
3738
Q_RESOLUTION = 500
3839

39-
DEFAULT_MINIZER = AvailableMinimizers.LMFit_leastsq
40+
DEFAULT_MINIMIZER = AvailableMinimizers.LMFit_leastsq
4041

4142

4243
class Project:
@@ -48,6 +49,7 @@ def __init__(self):
4849
self._calculator = CalculatorFactory()
4950
self._experiments: Dict[DataGroup] = {}
5051
self._fitter: MultiFitter = None
52+
self._minimizer_selection: AvailableMinimizers = DEFAULT_MINIMIZER
5153
self._colors: list[str] = None
5254
self._report = None
5355
self._q_min: float = None
@@ -207,9 +209,8 @@ def models(self, models: ModelCollection) -> None:
207209
def fitter(self) -> MultiFitter:
208210
if len(self._models):
209211
if (self._fitter is None) or (self._fitter_model_index != self._current_model_index):
210-
minimizer = self.minimizer
211212
self._fitter = MultiFitter(self._models[self._current_model_index])
212-
self.minimizer = minimizer
213+
self._fitter.easy_science_multi_fitter.switch_minimizer(self._minimizer_selection)
213214
self._fitter_model_index = self._current_model_index
214215
return self._fitter
215216

@@ -225,10 +226,14 @@ def calculator(self, calculator: str) -> None:
225226
def minimizer(self) -> AvailableMinimizers:
226227
if self._fitter is not None:
227228
return self._fitter.easy_science_multi_fitter.minimizer.enum
228-
return DEFAULT_MINIMIZER
229+
return self._minimizer_selection
229230

230231
@minimizer.setter
231232
def minimizer(self, minimizer: AvailableMinimizers) -> None:
233+
old_name = getattr(self._minimizer_selection, 'name', str(self._minimizer_selection))
234+
new_name = getattr(minimizer, 'name', str(minimizer))
235+
logger.info('Minimizer changed from %s to %s (fitter active: %s)', old_name, new_name, self._fitter is not None)
236+
self._minimizer_selection = minimizer
232237
if self._fitter is not None:
233238
self._fitter.easy_science_multi_fitter.switch_minimizer(minimizer)
234239

@@ -386,21 +391,10 @@ def _apply_resolution_function(
386391
) -> None:
387392
"""Set the resolution function on *model* based on variance data in *experiment*.
388393
389-
Prefers Pointwise when q-resolution (xe) data is present, otherwise falls
390-
back to LinearSpline when reflectivity error (ye) data is present.
391-
392394
:param experiment: The experiment whose variance data drives the choice.
393395
:param model: The model whose resolution function is set.
394396
"""
395-
if sum(experiment.xe) != 0:
396-
resolution_function = Pointwise(q_data_points=[experiment.x, experiment.y, experiment.xe])
397-
model.resolution_function = resolution_function
398-
elif sum(experiment.ye) != 0:
399-
resolution_function = LinearSpline(
400-
q_data_points=experiment.x,
401-
fwhm_values=np.sqrt(experiment.ye),
402-
)
403-
model.resolution_function = resolution_function
397+
model.resolution_function = PercentageFwhm(5.0)
404398

405399
def load_new_experiment(self, path: Union[Path, str]) -> None:
406400
new_experiment = load_as_dataset(str(path))
@@ -603,6 +597,8 @@ def as_dict(self, include_materials_not_in_model=False):
603597
self._as_dict_add_experiments(project_dict)
604598
if self.fitter is not None:
605599
project_dict['fitter_minimizer'] = self.fitter.easy_science_multi_fitter.minimizer.name
600+
elif self._minimizer_selection is not None:
601+
project_dict['fitter_minimizer'] = self._minimizer_selection.name
606602
if self._calculator is not None:
607603
project_dict['calculator'] = self._calculator.current_interface_name
608604
if self._colors is not None:
@@ -641,14 +637,17 @@ def from_dict(self, project_dict: dict):
641637
if 'materials_not_in_model' in keys:
642638
self._materials.extend(MaterialCollection.from_dict(project_dict['materials_not_in_model']))
643639
if 'fitter_minimizer' in keys:
644-
self.fitter.easy_science_multi_fitter.switch_minimizer(AvailableMinimizers[project_dict['fitter_minimizer']])
640+
self.minimizer = AvailableMinimizers[project_dict['fitter_minimizer']]
645641
else:
646642
self._fitter = None
647643
if 'experiments' in keys:
648644
self._experiments = self._from_dict_extract_experiments(project_dict)
649645
else:
650646
self._experiments = {}
651647

648+
# Resolve any pending parameter dependencies (constraints) after all objects are loaded
649+
resolve_all_parameter_dependencies(self)
650+
652651
def _from_dict_extract_experiments(self, project_dict: dict) -> Dict[int, DataSet1D]:
653652
experiments = {}
654653
for key in project_dict['experiments'].keys():

tests/summary/test_summary.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def test_experiments_section(self, project: Project) -> None:
133133
assert 'No. of data points' in html
134134
assert '408' in html
135135
assert 'Resolution function' in html
136-
assert 'Pointwise' in html
136+
assert 'PercentageFwhm' in html
137137

138138
def test_experiments_section_percentage_fhwm(self, project: Project) -> None:
139139
# When

0 commit comments

Comments
 (0)