[1438] Adding padding to end_date to avoid duplicate samples #1749
enssow merged 37 commits into ecmwf:develop from
Conversation
Tested on SANTIS for:
Thanks @enssow for handling this issue. But I still do not know which would be the better strategy: should we pad with available dates, or decrease num_samples to fit the defined date range? And what if there is no data after the defined end_date? Will it throw an error or return empty tensors?
Can you run some inference with these options? I saw some unwanted behaviour there.
…ument (switching to JUWELS)
…to sorcha/dev/1438
@enssow: what is the root cause of the length mismatch?
@clessig this is now ready for review; I reverted some of the changes per the meeting. Should I publicise it on the PR channel, or would you prefer to give it a look? Thank you in advance!
Expected behaviour: Given a specified date range (either given or using the default test_config), if the user requests more samples than are available (calculated using the length of time required to fit samples separated by
Tested on SANTIS with varying start, end dates, num_samples and num_steps: |
grassesi left a comment
I very much appreciate you breaking up the constructor. However, I would like to take the opportunity to also reduce the excessive use of object state to pass around information that could be handled by return values instead. This will make the reset logic much more explicit and easier to follow.
Additionally, I think there is a small oversight: when calculating the length of the batch index permutation array, and when checking the number of available init_times (and thus the number of available batches), you compare against self.samples_per_mini_epoch, which includes the additional factor batch_size.
I have implemented the suggestions here: https://github.com/grassesi/WeatherGenerator/tree/sgrasse/sorcha/1438-fix feel free to merge them if they look good to you. Please tell me if I overlooked anything.
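To illustrate the review suggestion in isolation: prefer return values over writing intermediate results into object state, so the caller decides where the result lives. This is a hypothetical sketch, not code from the PR; the function name mirrors `_init_stream_datasets` and the stream names are made up.

```python
# Hypothetical sketch: a side-effect-free variant of the initializer.
# Instead of `self.streams_datasets = {...}` hidden inside the method,
# the mapping is returned and assigned at the single call site.
def init_stream_datasets(stream_names):
    """Return a mapping from stream name to its (initially empty) dataset list."""
    return {name: [] for name in stream_names}

# call site, e.g. in __init__ (stream names are illustrative):
streams_datasets = init_stream_datasets(["era5", "synop"])
```

The same pattern applies to the other side-effect heavy helpers discussed below: compute, return, assign once.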
```python
def _init_forecast_cfg(self):
    if len(self.forecast_cfg) == 0:
        self.list_num_forecast_steps = np.array([0], dtype=np.int32)
        self.output_offset = 0
        self.forecast_policy = None
        self.time_step = np.timedelta64(0, "ms")
        return

fsm = self.list_num_forecast_steps[0]
forecast_len = (self.time_step * (fsm + 1)) // self.step_timedelta
perms_len = perms_len - (forecast_len + self.output_offset)
self.output_offset = self.forecast_cfg.get("offset", 0)
self.time_step = self.forecast_cfg.get("time_step", np.timedelta64(0, "ms"))
self.forecast_policy = self.forecast_cfg.get("policy", None)

self.repeat_data = cf.data_loading.get("repeat_data_in_mini_epoch", False)
if isinstance(self.forecast_cfg.num_steps, int):
    steps = [self.forecast_cfg.num_steps]
else:
    steps = self.forecast_cfg.num_steps

self.list_num_forecast_steps = np.array(steps, dtype=np.int32)
```
This function can be simplified and inlined:

```python
def __init__():
    ...
    forecast_cfg = FORECAST_DEFAULTS | mode_cfg.get("forecast", {})
    self.output_offset = forecast_cfg["offset"]
    self.time_step = forecast_cfg["time_step"]
    self.forecast_policy = forecast_cfg["policy"]
    steps = np.array(forecast_cfg.num_steps, dtype=np.int32).reshape(-1)
    self.list_num_forecast_steps = np.array(steps, dtype=np.int32)
```

...where FORECAST_DEFAULTS is defined as a module-level constant:

```python
FORECAST_DEFAULTS = {
    "offset": 0,
    "time_step": np.timedelta64(0, "ms"),
    "policy": None,
    "num_steps": np.array([0], dtype=np.int32),
}
```

This is shorter, more direct, and will avoid creating a new member (self.forecast_cfg) and the side-effect heavy method _init_forecast_cfg.
```python
self.forecast_cfg = mode_cfg.get("forecast", {})
self._init_forecast_cfg()
```
see my comment on _init_forecast_cfg()
```python
@@ -200,57 +261,69 @@ def __init__(

self.streams_datasets[stream_info["name"]] += [ds]
```
Thanks for separating this part out. However it would be cool if the initialized streams_datasets are passed as a return value. There is also a second more subtle side effect of this method: it adds some information to the global stream configs (eg. channel names). Right now this is hard to avoid, but I would like to see it mentioned in the docstring of this method.
```python
self.samples_per_mini_epoch = mode_cfg.samples_per_mini_epoch
self.check_samples()
self.calc_baseperms()
self._init_stream_datasets(cf)
```
```python
self.streams_datasets = self._init_stream_datasets(cf)
```

(see my comment on _init_stream_datasets)
```python
# choose correct num samples
if not self.repeat_data and self.samples_per_mini_epoch:
    if self.samples_per_mini_epoch >= available_samples:
```
These are not comparable quantities: samples_per_mini_epoch includes the additional factor self.batch_size whereas available samples is the number of possible init_times == batches. But there can be multiple samples, sampled from one init time.
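A minimal sketch of the unit mismatch described in this comment, with made-up numbers: `available_samples` counts possible init_times (i.e. batches), while `samples_per_mini_epoch` counts individual samples, which is batches times `batch_size`. Converting to a common unit before comparing avoids the off-by-a-factor error.

```python
# Illustrative numbers only; names mirror the PR but this is not its code.
batch_size = 8
available_init_times = 100      # == number of available batches
samples_per_mini_epoch = 400    # counts samples, not batches

# Comparing the two directly mixes units. Converting first is consistent:
batches_per_mini_epoch = samples_per_mini_epoch // batch_size
fits = batches_per_mini_epoch <= available_init_times
```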
```python
        samples_per_mini_epoch reduced to {available_samples} to avoid repeating data. \
        Set repeat_data_in_mini_epoch to True if this is undesired."
    )
    self.samples_per_mini_epoch = available_samples - 1
```

```python
perms_len = int(self.index_range.end - self.index_range.start)
perms_len -= (self.fsm + self.output_offset) * (self.time_step // self.step_timedelta)
self.base_perms = np.arange(perms_len)
```
This method's only purpose is to provide self.base_perms, which itself is only ever used once, in reset(). Returning the value here would reduce the namespace of this class.
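As a hypothetical sketch of that suggestion (names mirror the PR, but the signature and numbers are invented for illustration): calc_baseperms computes and returns the permutation base, and reset() consumes the return value directly, so nothing needs to live on the object.

```python
import numpy as np

def calc_baseperms(index_start: int, index_end: int, reserved_steps: int) -> np.ndarray:
    """Return base indices for the usable part of the time window.

    `reserved_steps` stands in for the forecast-length/offset term
    subtracted in the PR's version.
    """
    perms_len = int(index_end - index_start) - reserved_steps
    return np.arange(perms_len)

# in reset(): shuffle the returned base indices, no self.base_perms needed
rng = np.random.default_rng(0)
perms = rng.permutation(calc_baseperms(0, 20, 4))
```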
```python
if self.repeat_data and len(perms) < self.samples_per_mini_epoch:
    perms = np.tile(perms, self.samples_per_mini_epoch // len(perms))
    filler = self.rng.choice(
        perms,
        size=self.samples_per_mini_epoch - len(perms),
```
see my comment on calc_baseperms: perms holds batch indices for the entire mini epoch. The length is therefore off by a factor of self.batch_size from self.samples_per_mini_epoch.
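A standalone sketch of the tile-and-fill logic in the snippet above, with illustrative numbers; `target_len` plays the role of samples_per_mini_epoch (modulo the batch_size factor this comment raises):

```python
import numpy as np

rng = np.random.default_rng(0)
perms = np.arange(5)      # 5 available indices
target_len = 12           # requested length, not a multiple of 5

if len(perms) < target_len:
    # whole repetitions first: 12 // 5 == 2 copies -> 10 entries
    tiled = np.tile(perms, target_len // len(perms))
    # then randomly sample the remaining 2 entries to reach target_len
    filler = rng.choice(perms, size=target_len - len(tiled))
    perms = np.concatenate([tiled, filler])
```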
Note: adjusted to let samples_per_mini_epoch=1 if available_samples=1; otherwise return an error like:
Testing complete, with the following command
Description
TimeWindowHandler doesn't produce enough available forecast initialisation times to choose samples from when running inference on a model trained with $n_{fstep}$ forecast steps and $n_{samples} \cdot dt \geq t_{end} - t_{start}$,
where $n_{fstep}$ = `--forecast_steps`, $n_{samples}$ = `--samples`, $dt$ = `--step_hours`, $t_{start}$ = `--start`, $t_{end}$ = `--end`. (See #1438 and #1085 for more info.)
This PR provides this padding by working out how many individual initialisation times are available and adjusting the end of the time window to accommodate them, taking into account the extra time needed to accommodate the number of forecast steps to roll out.
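The padding arithmetic described above can be sketched as follows. This is an illustrative reconstruction, not the PR's actual code: the variable names and the exact formula (how many step-widths to reserve for the forecast rollout) are assumptions.

```python
import numpy as np

# Assumed inputs, mirroring the CLI flags in the description.
start = np.datetime64("2020-01-01T00:00")   # --start
dt = np.timedelta64(6, "h")                 # --step_hours
num_samples = 8                             # --samples
forecast_steps = 4                          # --forecast_steps

# Pad the window end so that num_samples init times, each also needing
# forecast_steps further steps of length dt for the rollout, all fit.
required_end = start + dt * (num_samples + forecast_steps)
```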
Issue Number
Closes #1438 #2059
Checklist before asking for review
- `./scripts/actions.sh lint`
- `./scripts/actions.sh unit-test`
- `./scripts/actions.sh integration-test`
- `launch-slurm.py --time 60`