
[1438] Adding padding to end_date to avoid duplicate samples #1749

Merged
enssow merged 37 commits into ecmwf:develop from enssow:sorcha/dev/1438
Apr 21, 2026

Conversation

@enssow
Contributor

@enssow enssow commented Jan 29, 2026

Description

TimeWindowHandler doesn't produce enough available forecast initialisation times to choose samples from when running inference on a model trained with $n_{fstep}$ forecast steps and $n_{samples} \cdot dt \geq t_{end} - t_{start}$,
where $n_{fstep}$ = --forecast_steps, $n_{samples}$ = --samples, $dt$ = --step_hours, $t_{start}$ = --start, $t_{end}$ = --end.
(See #1438 and #1085 for more info.)

This PR provides the necessary padding by working out how many individual initialisation times are available and adjusting the end of the time window to accommodate them, taking into account the extra time needed for the requested number of forecast rollout steps.
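The padding described above can be sketched as follows. This is a minimal illustration with hypothetical names and signatures, not the PR's actual implementation (which lives in TimeWindowHandler):

```python
import numpy as np

def pad_end_date(start, end, n_samples, step, n_fstep, forecast_step_len):
    """Extend `end` so that `n_samples` init times fit, plus rollout room.

    Hypothetical sketch of the padding idea; all names are illustrative.
    """
    # time span needed for the requested init times, separated by `step`
    needed = n_samples * step
    # extra span needed to roll out the forecast after the last init time
    rollout = n_fstep * forecast_step_len
    required_end = start + needed + rollout
    # pad only when the requested window is too short
    return max(end, required_end)

start = np.datetime64("2022-10-01T00")
end = np.datetime64("2022-10-02T00")
new_end = pad_end_date(start, end, n_samples=10,
                       step=np.timedelta64(6, "h"),
                       n_fstep=5, forecast_step_len=np.timedelta64(6, "h"))
print(new_end)  # → 2022-10-04T18
```

With 10 samples 6 h apart plus a 5-step rollout, the window must cover 90 h, so the one-day range is padded out to 2022-10-04T18.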

Issue Number

Closes #1438, #2059

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a hedgedoc in the github issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@enssow
Contributor Author

enssow commented Jan 29, 2026

Tested on SANTIS for:

  • uv run inference --from-run-id f4duf5ji --samples 254 --streams-output ERA5 --options training_config.forecast.num_steps=5
  • uv run inference --from-run-id f4duf5ji --samples 254 --streams-output ERA5

Both no longer return the duplication warning, and inference_id cixysv6l ran to completion successfully.

@ankitpatnala
Contributor

ankitpatnala commented Feb 3, 2026

Thanks @enssow for handling this issue
I tested the code using
srun uv run inference --from-run-id f4duf5ji --samples 10 --start="2022-10-01" --end="2022-10-02" --streams-output ERA5 --options training_config.forecast.num_steps=5
The code functioned the way it has been described.

But I still do not know which strategy is better: should we pad with available dates, or decrease num_samples to fit the defined date range? And what if there is no data after the defined end_date? Will it throw an error or return empty tensors?

@ankitpatnala
Contributor

Can you run some inference with this options
srun uv run inference --from-run-id f4duf5ji --samples 10 --start=2022-10-01 --end=2022-10-02 --streams-output ERA5 --options training_config.forecast.num_steps=10 training_config.forecast.time_step=03:00:00

I saw some unwanted behaviour there

Logging set up. Logs are in ./output/uv3yi4ac
DDP initialization: rank=0, world_size=1
Using adjusted end date 2022-10-03T00:00:00.000000000 instead of 2022-10-02T00:00:00.000000000
TimeWindowHandler: start=2022-10-01T00:00:00.000000000, end=2022-10-03T00:00:00.000000000, len=06:00:00, step=06:00:00

@github-actions Bot added the labels bug (Something isn't working), data (Anything related to the datasets used in the project), model (Related to model training or definition, not generic infra) and model:inference (anything related to the inference step, not plotting or score computation) on Feb 19, 2026
@clessig clessig self-requested a review March 2, 2026 15:54
@enssow enssow marked this pull request as draft March 12, 2026 12:09
@clessig
Collaborator

clessig commented Mar 15, 2026

@enssow : what is the root cause of the length mismatch?

@enssow
Contributor Author

enssow commented Mar 19, 2026

@clessig this is now ready for review; I reverted some of the changes per the meeting. Should I publicise it on the PR channel, or would you prefer to give it a look? Thank you in advance!

@enssow
Contributor Author

enssow commented Apr 2, 2026

Expected behaviour: given a specified date range (either user-provided or the default from test_config), if the user requests more samples than are available (calculated from the length of time required to fit samples separated by step_timedelta, plus the number of forecast steps after the last sample), the MultiStreamDataSampler will prioritise the specified date range and reduce the number of samples requested, raising a warning for the user. If the number of available samples is non-positive, the user needs to either reduce the number of forecast steps or increase the date range, and an AssertionError is raised that stops the run.
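A minimal sketch of that clamping behaviour (function name and message wording are illustrative, not the PR's actual code):

```python
import warnings

def clamp_samples(requested: int, available: int) -> int:
    """Reduce the requested sample count to what the date range allows."""
    # no room for even one sample: the user must widen the date range
    # or reduce the number of forecast steps
    assert available > 0, (
        "Insufficient date range to accommodate any samples or forecast steps"
    )
    if requested > available:
        warnings.warn(
            f"samples_per_mini_epoch reduced to {available} to avoid repeating data"
        )
        return available
    return requested

print(clamp_samples(10, 1))  # → 1 (with a warning)
```

The specified date range always wins: the sample count is clamped downward rather than the window being extended silently.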

@enssow
Contributor Author

enssow commented Apr 2, 2026

Tested on SANTIS with varying start, end dates, num_samples and num_steps:
uv run inference --from-run-id wmway1iv --options training_config.forecast.num_steps=5 test_config.output.num_samples=5 test_config.start_date=2022-10-01 test_config.end_date=2022-10-03

Contributor

@grassesi grassesi left a comment


I very much appreciate you breaking up the constructor. However, I would like to take the opportunity to also reduce the excessive use of object state to pass around information that could be handled by return values instead. This would make the reset logic much more explicit and easier to follow.
Additionally, I think there is a small oversight: when calculating the length of the batch index permutation array, and when checking the number of available init_times (and thus the number of available batches), you compare against self.samples_per_mini_epoch, which includes the additional factor batch_size.

I have implemented the suggestions here: https://github.com/grassesi/WeatherGenerator/tree/sgrasse/sorcha/1438-fix feel free to merge them if they look good to you. Please tell me if I overlooked anything.

Comment on lines +136 to +153
def _init_forecast_cfg(self):
    if len(self.forecast_cfg) == 0:
        self.list_num_forecast_steps = np.array([0], dtype=np.int32)
        self.output_offset = 0
        self.forecast_policy = None
        self.time_step = np.timedelta64(0, "ms")
        return

    self.output_offset = self.forecast_cfg.get("offset", 0)
    self.time_step = self.forecast_cfg.get("time_step", np.timedelta64(0, "ms"))
    self.forecast_policy = self.forecast_cfg.get("policy", None)

    self.repeat_data = cf.data_loading.get("repeat_data_in_mini_epoch", False)
    if isinstance(self.forecast_cfg.num_steps, int):
        steps = [self.forecast_cfg.num_steps]
    else:
        steps = self.forecast_cfg.num_steps

    self.list_num_forecast_steps = np.array(steps, dtype=np.int32)
Contributor


This function can be simplified and inlined:

def __init__():
    ...
    forecast_cfg = FORECAST_DEFAULTS | mode_cfg.get("forecast", {})
    self.output_offset = forecast_cfg["offset"]
    self.time_step = forecast_cfg["time_step"]
    self.forecast_policy = forecast_cfg["policy"]
    steps = np.array(forecast_cfg["num_steps"], dtype=np.int32).reshape(-1)
    self.list_num_forecast_steps = steps
    ...

where FORECAST_DEFAULTS is defined as module level constant:

FORECAST_DEFAULTS = {
    "offset": 0,
    "time_step": np.timedelta64(0, "ms"),
    "policy": None,
    "num_steps": np.array([0], dtype=np.int32),
}

This is shorter, more direct and will avoid creating a new member (self.forecast_cfg) and the side-effect heavy method _init_forecast_cfg
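The default-merging pattern suggested above can be demonstrated standalone. The `|` dict-union operator needs Python 3.9+, and `mode_cfg` here is a plain-dict stand-in for the real config object:

```python
import numpy as np

FORECAST_DEFAULTS = {
    "offset": 0,
    "time_step": np.timedelta64(0, "ms"),
    "policy": None,
    "num_steps": [0],
}

# stand-in for the real mode config; only the overridden keys are given
mode_cfg = {"forecast": {"num_steps": 5, "offset": 1}}

# the right-hand operand wins on duplicate keys, so user values override defaults
forecast_cfg = FORECAST_DEFAULTS | mode_cfg.get("forecast", {})

# reshape(-1) normalises a scalar num_steps into a 1-element array
list_num_forecast_steps = np.array(forecast_cfg["num_steps"], dtype=np.int32).reshape(-1)
print(forecast_cfg["offset"], list_num_forecast_steps)  # → 1 [5]
```

Keys absent from the user config (here "policy" and "time_step") fall through to the module-level defaults, so no per-key `.get(..., default)` calls are needed.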

Comment on lines +99 to +100
self.forecast_cfg = mode_cfg.get("forecast", {})
self._init_forecast_cfg()
Contributor


see my comment on _init_forecast_cfg()

Comment on lines 199 to 262
@@ -200,57 +261,69 @@ def __init__(

self.streams_datasets[stream_info["name"]] += [ds]
Contributor


Thanks for separating this part out. However, it would be cool if the initialized streams_datasets were passed as a return value. There is also a second, more subtle side effect of this method: it adds some information to the global stream configs (e.g. channel names). Right now this is hard to avoid, but I would like to see it mentioned in the docstring of this method.

self.samples_per_mini_epoch = mode_cfg.samples_per_mini_epoch
self.check_samples()
self.calc_baseperms()
self._init_stream_datasets(cf)
Contributor


self.streams_datasets = self._init_stream_datasets(cf) (see my comment on _init_stream_datasets)


# choose correct num samples
if not self.repeat_data and self.samples_per_mini_epoch:
if self.samples_per_mini_epoch >= available_samples:
Contributor


These are not comparable quantities: samples_per_mini_epoch includes the additional factor self.batch_size, whereas available_samples is the number of possible init_times == batches. But there can be multiple samples sampled from one init time.

samples_per_mini_epoch reduced to {available_samples} to avoid repeating data. \
Set repeat_data_in_mini_epoch to True if this is undesired."
)
self.samples_per_mini_epoch = available_samples - 1
Contributor


see l. 171

perms_len = int(self.index_range.end - self.index_range.start)
perms_len -= (self.fsm + self.output_offset) * (self.time_step // self.step_timedelta)
self.base_perms = np.arange(perms_len)

Contributor


This method's only purpose is to provide self.base_perms, which itself is only ever used once in reset(). Returning the value here would reduce the namespace of this class.
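The return-instead-of-state refactor suggested here could look like this; the signature and names are hypothetical, mirroring the quoted snippet:

```python
import numpy as np

def calc_base_perms(index_start, index_end, fsm, output_offset,
                    time_step, step_timedelta):
    """Return the base permutation indices instead of storing them on self."""
    perms_len = int(index_end - index_start)
    # trim indices that cannot serve as init times because the forecast
    # rollout (plus offset) would run past the end of the window
    perms_len -= (fsm + output_offset) * (time_step // step_timedelta)
    return np.arange(perms_len)

base_perms = calc_base_perms(0, 20, fsm=5, output_offset=1,
                             time_step=np.timedelta64(6, "h"),
                             step_timedelta=np.timedelta64(6, "h"))
print(len(base_perms))  # → 14
```

With 20 window indices and a 5-step rollout plus offset 1 at matching step sizes, 6 trailing indices are trimmed, leaving 14 usable init times.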

Comment on lines +278 to +282
if self.repeat_data and len(perms) < self.samples_per_mini_epoch:
perms = np.tile(perms, self.samples_per_mini_epoch // len(perms))
filler = self.rng.choice(
perms,
size=self.samples_per_mini_epoch - len(perms),
Contributor


see my comment on calc_baseperms: perms holds batch indices for the entire mini epoch. The length is therefore off by a factor of self.batch_size from self.samples_per_mini_epoch.

@github-project-automation github-project-automation Bot moved this to In Progress in WeatherGen-dev Apr 13, 2026
@github-actions github-actions Bot added the infra Issues related to infrastructure label Apr 14, 2026
@enssow
Contributor Author

enssow commented Apr 21, 2026

Note: adjusted to let samples_per_mini_epoch=1 if available_samples=1; otherwise it would return an error like:

uv run inference --from-run-id e3u8p2bn --options test_config.samples_per_mini_epoch=10 test_config.output.num_samples=10 test_config.forecast.num_steps=5 test_config.forecast.offset=1 test_config.start_date=2023-01-10 test_config.end_date=2023-01-12

Samples will be repeated within the time range
0it [00:00, ?it/s]/users/sowens/WeatherGenerator/src/weathergen/train/trainer.py:785: RuntimeWarning: Mean of empty slice
  {np.nanmean(avg_loss)}"""
validation (xjazf2pi) : 000 : 
                        nan


0it [00:00, ?it/s]
FAIL!

@enssow
Contributor Author

enssow commented Apr 21, 2026

Testing complete with the following command:
uv run inference --from-run-id e3u8p2bn --options test_config.samples_per_mini_epoch=10 test_config.output.num_samples=10 test_config.forecast.num_steps=5 test_config.forecast.offset=1
using the extra options of:

| test_config.start_date | test_config.end_date | result |
| --- | --- | --- |
| "2023-01-10" | "2023-01-11" | AssertionError: There is an insufficient date range to accomodate any number of samples or forecast steps |
| "2023-01-10" | "2023-01-13" | Sufficient available samples in the time range specified; checked no duplicate samples |
| "2023-01-10" | "2023-01-12" | There are only 1 available_samples, samples_per_mini_epoch reduced to 1 to avoid repeating data. Set repeat_data_in_mini_epoch to True if this is undesired. |

@enssow enssow merged commit 18c09b3 into ecmwf:develop Apr 21, 2026
5 of 6 checks passed
@github-project-automation github-project-automation Bot moved this from In Progress to Done in WeatherGen-dev Apr 21, 2026


Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

Duplicate samples during inference due to different length assumptions in MultiStreamDataReader

5 participants