Training API Docs
training.dataindex
- class pyearthtools.training.dataindex.MLDataIndex(wrapper, *, data_interval, cache=None, prediction_function='predict', prediction_config=None, offsetInterval=False, post_transforms=None, override=False, data_attributes=None, **kwargs)
pyearthtools.training DataIndex that uses an underlying ML model to generate data to cache.
Set up the ML data index from a defined wrapper.
- Info:
This can be used just like an Index from pyearthtools.data (see pyearthtools.data.indexes): calling or indexing into this object works, as does supplying transforms.
- Parameters:
wrapper (pyearthtoolsTrainer) – pyearthtoolsTrainer to use to retrieve data
data_interval (tuple) – Resolution that the wrapper operates at, in TimeDelta form, e.g. (1, 'day').
cache (str | Path, optional) – Location to cache outputs. If not supplied, outputs are not cached.
prediction_function (str, optional) – Function to use for prediction.
prediction_config (dict, optional) – Configuration for predictions.
offsetInterval (bool, optional) – Whether to offset time by interval. Defaults to False.
post_transforms (Transform | TransformCollection | None, optional) – Transforms to apply post generation. Defaults to None.
override (bool, optional) – Override any generated data. Defaults to False.
data_attributes (str | Path | None, optional) – Path to yaml file specifying attributes to set.
**kwargs (dict, optional) – Any keyword arguments to pass to pyearthtools.data.BaseCacheIndex.
- property data
Get the data pipeline.
- get(*args, **kwargs)
Retrieve the prediction data for the given request.
training.manage
- class pyearthtools.training.manage.Variables(variables=None, *, order=None, **kwargs)
Variable management class.
Allows the categories of variables in an array, and their ordering, to be specified.
- Provides functions for:
Reordering
Splitting
Adding
Removing
Extracting
Joining
These functions can be run on any incoming data, with a choice of the order associated.
The order is specified as a string made of the capitalised first letter of each category, e.g. 'PFD' for prognostics, forcings, diagnostics.
All arrays are expected to be channel / variable first.
E.g.
>>> Variables(order = 'PFD', prognostics = 10, diagnostics = 5, forcings = 7)
Variables - (PFD)
    prognostics - 10
    forcings - 7
    diagnostics - 5
>>> variables = Variables(order = 'PFD', prognostics = 10, diagnostics = 5, forcings = 7)
>>> variables.split(data)
{
    # Categories of data split accordingly
}
>>> variables.extract(data, 'diagnostics')
# diagnostics extracted from the data
Set up the variable manager.
- Parameters:
order (Optional[str], optional) – Order of categories, using the capitalised first letter of each. If not given, it is inferred from the order of kwargs.
variables (Optional[Union[Variables, dict[str, list[str] | int]]])
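When `order` is not given, it is inferred from the order of the keyword arguments. The snippet below is a minimal stdlib sketch of that rule and of the letter check that `check_order` describes; `infer_order` and `check_order_letters` are hypothetical helpers, and it assumes every category is identified by its capitalised first letter.

```python
def infer_order(**categories: int) -> str:
    """Build an order string from keyword-argument order (kwargs keep insertion order)."""
    return "".join(name[0].upper() for name in categories)


def check_order_letters(order: str, **categories: int) -> bool:
    """Return True if every letter in `order` refers to a known category."""
    known = {name[0].upper() for name in categories}
    return all(letter in known for letter in order)


counts = dict(prognostics=10, forcings=7, diagnostics=5)
print(infer_order(**counts))                # PFD
print(check_order_letters("PD", **counts))  # True
print(check_order_letters("PX", **counts))  # False
```

Under this rule, `Variables(prognostics = 10, diagnostics = 5, forcings = 7)` without `order` would give `PDF`, so passing `order` explicitly matters whenever the keyword order differs from the desired layout.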
- add(data: D, incoming: tuple[D], category: tuple[str], order: str | None = None) -> D
- add(data: D, incoming: D, category: str, order: str | None = None) -> D
Add `incoming` data into `data`, in the correct position. The result is reordered into the order specified at initialisation.
See `.join` for explicit assignment.
- Parameters:
data (DATA_TYPES) – Data to add to.
incoming (DATA_TYPES) – Data to add.
category (str) – Category of data to add.
order (Optional[str], optional) – Order of `data` if different; it should not contain an entry for `category`. Defaults to None.
- Returns:
Merged data in the order specified in `__init__`.
- Return type:
(DATA_TYPES)
Examples
>>> import numpy as np
>>> variables = Variables(order = 'PFD', prognostics = 10, diagnostics = 5, forcings = 7)
>>> forcings_missing = np.ones(15)
>>> forcings = np.zeros(7)
>>> data = variables.add(forcings_missing, forcings, 'forcings', order = 'PD')
>>> data.shape
(22,)
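The `add` behaviour can be sketched with plain Python lists standing in for channel-first arrays. This is a hypothetical re-implementation for illustration only (`slices_for` and `add_category` are not pyearthtools API). It assumes each category occupies one contiguous block of channels, laid out in the order string.

```python
def slices_for(order: str, lengths: dict[str, int]) -> dict[str, slice]:
    """Contiguous channel slice for each category letter in `order`."""
    out, start = {}, 0
    for letter in order:
        out[letter] = slice(start, start + lengths[letter])
        start += lengths[letter]
    return out


def add_category(data: list, incoming: list, category: str, order: str,
                 full_order: str, lengths: dict[str, int]) -> list:
    """Insert `incoming` as `category` into `data` (laid out in `order`), return `full_order` layout."""
    pieces = {letter: data[s] for letter, s in slices_for(order, lengths).items()}
    pieces[category[0].upper()] = incoming
    merged = []
    for letter in full_order:
        merged.extend(pieces[letter])
    return merged


lengths = {"P": 10, "F": 7, "D": 5}
forcings_missing = [1.0] * 15  # prognostics + diagnostics, order 'PD'
forcings = [0.0] * 7
merged = add_category(forcings_missing, forcings, "forcings", "PD", "PFD", lengths)
print(len(merged))  # 22
```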
- property category_names: list[str]
Get the names of the specified categories. These may not be in order.
- check_category(category)
Check whether `category` is one of the specified categories.
- Parameters:
category (str)
- check_order(order)
Check that all elements of `order` are valid.
- Parameters:
order (str)
- compare_length(data, order=None, error=False)
Compare the length of `data` to the expected length.
Raise an error on a mismatch if `error` is True.
- Parameters:
data (ndarray | Dataset | Tensor | list)
order (str | None)
error (bool)
- Return type:
bool
- extract(data: D, category: tuple[str, ...], order: str | None = None) -> tuple[D, ...]
- extract(data: D, category: str, order: str | None = None) -> D
Extract a category from `data`.
If `order` is given, `data` can contain only a subset of categories and will still be extracted correctly.
- Parameters:
data (DATA_TYPES) – Data to extract category from.
category (tuple[str, ...] | str) – Category to extract
order (Optional[str], optional) – Order if different to specified. Defaults to None.
- Returns:
Extracted data
- Return type:
(tuple[DATA_TYPES, …] | DATA_TYPES)
- join(**kwargs)
Join the data given in `kwargs`, ordered according to the specified order.
- Parameters:
kwargs (ndarray | Dataset | Tensor | list)
- Return type:
ndarray | Dataset | Tensor | list
- length(category)
Get the length of the given categories.
- Parameters:
category (str | list[str] | tuple[str, ...])
- Return type:
int
- length_of_order(order)
Get the expected length of the categories listed in `order`.
- Parameters:
order (str)
- Return type:
int
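`length_of_order` and `compare_length` together amount to summing the per-category channel counts named in an order string and checking the data's leading dimension against that sum. A hypothetical stdlib sketch, assuming categories are keyed by first letter and using `len()` in place of an array's first dimension:

```python
LENGTHS = {"prognostics": 10, "forcings": 7, "diagnostics": 5}


def length_of_order(order: str) -> int:
    """Expected number of channels for the categories listed in `order`."""
    by_letter = {name[0].upper(): n for name, n in LENGTHS.items()}
    return sum(by_letter[letter] for letter in order)


def compare_length(data, order: str, error: bool = False) -> bool:
    """Check len(data) against the expected length; optionally raise on a mismatch."""
    expected = length_of_order(order)
    matches = len(data) == expected
    if not matches and error:
        raise ValueError(f"Expected {expected} channels for order {order!r}, got {len(data)}")
    return matches


print(length_of_order("PFD"))            # 22
print(compare_length([0.0] * 15, "PD"))  # True
```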
- names_from_order(order=None, reorder=False)
Get the names of categories from `order`.
- Parameters:
order (Optional[str], optional) – Order to get names from. Defaults to None.
reorder (bool, optional) – Whether to reorder `order` into the order specified in `__init__`. Defaults to False.
- Returns:
Names of categories in `order`, or in the 'correct' order if `reorder`.
- Return type:
(tuple[str, ...])
- np_slices(order=None)
Get slices for extracting each category from data.
- Parameters:
order (Optional[str], optional) – Incoming order if different from specified. Defaults to None.
- Returns:
Category to slice pairs
- Return type:
(dict[str, slice])
- remove(data, category, order=None)
Remove a category of data.
If `order` is given, `data` can contain only a subset of categories, and the result is returned in the given order.
- Parameters:
data (DATA_TYPES) – Data to remove the category from.
category (str) – Category to remove.
order (Optional[str], optional) – Order if different to specified. Can be a subset. Defaults to None.
- Returns:
`data` with `category` removed. If `order` is given, it is maintained.
- Return type:
(DATA_TYPES)
- reorder(data, order=None)
Reorder incoming data into the originally specified order
- Parameters:
data (D)
order (str | None)
- Return type:
D
- split(data, order=None)
Split incoming data into the specified categories
- Parameters:
data (DATA_TYPES) – Data to split
order (Optional[str], optional) – Order of data, required if data is an array. Defaults to None.
- Returns:
Data split by category, keyed by the specified category names.
- Return type:
(dict[str, DATA_TYPES])
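`np_slices` and `split` can be pictured as mapping each category in an order string to a contiguous block of channels and cutting the data along those blocks. The sketch below is hypothetical: plain lists stand in for arrays, the category sizes are taken from the class example, and it is not the library's implementation.

```python
CATEGORIES = {"prognostics": 10, "forcings": 7, "diagnostics": 5}


def np_slices(order: str) -> dict[str, slice]:
    """Map each category named in `order` to its contiguous channel slice."""
    names = {name[0].upper(): name for name in CATEGORIES}
    slices, start = {}, 0
    for letter in order:
        name = names[letter]
        slices[name] = slice(start, start + CATEGORIES[name])
        start += CATEGORIES[name]
    return slices


def split(data: list, order: str) -> dict[str, list]:
    """Split channel-first `data` into a dict keyed by category name."""
    return {name: data[s] for name, s in np_slices(order).items()}


data = list(range(22))  # 22 channels laid out as P, F, D
parts = split(data, "PFD")
print({name: len(values) for name, values in parts.items()})
# {'prognostics': 10, 'forcings': 7, 'diagnostics': 5}
```

Passing a subset order such as `"PD"` shifts the diagnostics slice to start right after the prognostics, which is why `order` must describe the incoming data rather than the full layout.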
- xr_slices(data, order=None)
Get slices for use with xarray, returned as lists of variable names.
Requires `data`, in case variables are given as int.
- Parameters:
data (Dataset)
order (str | None)
- Return type:
dict[str, list[str]]
training.data
- class pyearthtools.training.data.PipelineDataModule(pipelines, train_split=None, valid_split=None)
Base PipelineDataModule.
`get_sample` can be used to retrieve data from `pipelines`, and `fake_batch_dim` can be overridden if special batch faking is needed. `train` configures the pipelines from `train_split`, and `valid` for validation.
Set up Pipelines for use with ML training.
- Parameters:
pipelines (dict[str, str | Pipeline | tuple[Pipeline, ...]] | tuple[Pipeline | str, ...] | Pipeline | str) – Pipelines for data retrieval. Can be a dictionary and/or list/tuple of Pipelines, or a single Pipeline.
train_split (Optional[Iterator], optional) – Iterator to use for training. Pipelines are configured by calling .train(). Defaults to None.
valid_split (Optional[Iterator], optional) – Iterator to use for validation. Pipelines are configured by calling .valid(). Defaults to None.
- check_for_use()
Check if the datamodule is ready for use.
- eval()
Set the Pipelines to iterate over `valid_split`.
- fake_batch_dim(sample)
Fake a batch dimension on `sample`.
- classmethod find_shape(obj)
Find the shape of `obj`.
- get_sample(idx, *, fake_batch_dim=False)
Get a sample from the pipelines at `idx`.
- Parameters:
fake_batch_dim (bool)
- classmethod load(stream, **kwargs)
Load a PipelineDataModule config.
- Parameters:
stream (str | Path) – File or dump to load.
kwargs (Any) – Updates to the default values included in the config.
- Returns:
Loaded PipelineDataModule
- Return type:
(PipelineDataModule)
- map_function_to_pipelines(function, **kwargs)
Map a function over the Pipelines.
- Parameters:
function (Callable[[Pipeline], Any])
- save(path=None)
Save the PipelineDataModule.
- Parameters:
path (Optional[str | Path], optional) – File to save to. If not given, return the save string. Defaults to None.
- Returns:
If `path` is None, the PipelineDataModule in save form, else None.
- Return type:
(Union[None, str])
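The save/load contract (return a string when no path is given, otherwise write to the file; load from a file or a dump and apply keyword overrides) follows a common pattern. The sketch below uses `json` only to stay stdlib-only. The actual serialisation format is not stated here, and `save_config`/`load_config` are hypothetical names. It also simplifies `stream` handling by treating a `Path` as a file and a `str` as a dump.

```python
import json
from pathlib import Path
from typing import Optional, Union


def save_config(config: dict, path: Optional[Union[str, Path]] = None) -> Optional[str]:
    """Return the serialised config if no path is given, else write it and return None."""
    dumped = json.dumps(config, indent=2)
    if path is None:
        return dumped
    Path(path).write_text(dumped)
    return None


def load_config(stream: Union[str, Path], **overrides) -> dict:
    """Load from a Path (file) or a str (dump), then apply keyword overrides."""
    text = Path(stream).read_text() if isinstance(stream, Path) else stream
    config = json.loads(text)
    config.update(overrides)
    return config


config = {"option": 1, "name": "example"}  # hypothetical config contents
dump = save_config(config)
print(load_config(dump, option=2))  # {'option': 2, 'name': 'example'}
```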
- train()
Set the Pipelines to iterate over `train_split`.
- pyearthtools.training.data.default()
Default DataModules.
Basic sampling and batching.
- pyearthtools.training.data.save(datamodule, path=None)
Save a PipelineDataModule.
- Parameters:
datamodule (PipelineDataModule) – Datamodule to save.
path (Optional[FILE], optional) – File to save to. If not given, return the save string. Defaults to None.
- Returns:
If `path` is None, `datamodule` in save form, else None.
- Return type:
(Union[None, str])
- pyearthtools.training.data.load(stream, **kwargs)
Load a datamodule config.
- Parameters:
stream (Union[str, Path]) – File or dump to load.
kwargs (Any) – Updates to the default values included in the config.
- Returns:
Loaded PipelineDataModule
- Return type:
(PipelineDataModule)
training.wrapper
- class pyearthtools.training.wrapper.ModelWrapper(model, data)
Base Model Wrapper.
Defines the interface through which to use a `model` and a `datamodule`/`Pipeline`.
Construct the base model wrapper.
`model` will not be recorded in the initialisation by default; set `_record_model` to change this behaviour.
- Parameters:
model (Any) – Model to use.
data (dict[str, Pipeline | tuple[Pipeline, ...]] | tuple[Pipeline, ...] | Pipeline | PipelineDataModule) – Data to use. If not a PipelineDataModule, it will be made into a `_default_datamodule`, which will then only have `get_sample`.
- get_sample(idx, *, fake_batch_dim=False)
Get a sample from the datamodule.
- Parameters:
fake_batch_dim (bool)
- abstractmethod load(*args, **kwargs)
Load the model.
- property pipelines
Get the pipelines from the datamodule.
- abstractmethod predict(data, *args, **kwargs)
Run a forward pass with the model.
- Parameters:
data – Data to run prediction with.
- abstractmethod save(*args, **kwargs)
Save the model.
- property splits
Training and validation splits as configured by the datamodule.
- class pyearthtools.training.wrapper.TrainingWrapper(model, data)
Model wrapper to enable training.
Construct the base model wrapper.
`model` will not be recorded in the initialisation by default; set `_record_model` to change this behaviour.
- Parameters:
model (Any) – Model to use.
data (dict[str, Pipeline | tuple[Pipeline, ...]] | tuple[Pipeline, ...] | Pipeline | PipelineDataModule) – Data to use. If not a PipelineDataModule, it will be made into a `_default_datamodule`, which will then only have `get_sample`.
- class pyearthtools.training.wrapper.Predictor(model, reverse_pipeline)
Wrapper to enable prediction.
- Hooks:
after_predict(prediction) -> prediction: Function executed after data has been reversed from prediction.
- Usage:
>>> model = ModelWrapper(MODEL_GOES_HERE, DATA_PIPELINE)
>>> predictor = Predictor(model)
>>> predictor.predict('2000-01-01T00')
Use a `model` to run a prediction.
Retrieves initial conditions from `model.get_sample`, so set its Pipeline accordingly.
- Parameters:
model (ModelWrapper) – Model and data source to use.
reverse_pipeline (Pipeline | int | str | None) – Override for the Pipeline to use on the undo operation. If not given, defaults to `model.pipelines`. If a str or int, the value is used to index into `model.pipelines`, which is useful if `model.pipelines` is a dictionary or tuple. Can also be a Pipeline itself. If `reverse_pipeline.has_source()` is True, `reverse_pipeline.undo` is run; otherwise the pipeline is applied with `reverse_pipeline.apply`.
- pyearthtools.training.wrapper.lightning.Predict
alias of LightingPrediction
- class pyearthtools.training.wrapper.lightning.predict.LoggingContext(change=True)
Quiet Lightning warnings.
- Parameters:
change (bool)
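One way a context manager like `LoggingContext` could quiet Lightning is to raise the level of Lightning's loggers and ignore Python warnings for the duration of the block, then restore the previous state. This is a hypothetical sketch. `quiet_lightning` is not pyearthtools API, the logger names `lightning.pytorch` and `pytorch_lightning` are assumptions about which loggers Lightning uses, and `change=False` is assumed to leave everything untouched.

```python
import logging
import warnings
from contextlib import contextmanager
from typing import Iterator, Sequence


@contextmanager
def quiet_lightning(change: bool = True,
                    names: Sequence[str] = ("lightning.pytorch", "pytorch_lightning")) -> Iterator[None]:
    """Temporarily silence the named loggers and Python warnings."""
    if not change:
        yield
        return
    loggers = [logging.getLogger(name) for name in names]
    previous = [logger.level for logger in loggers]
    for logger in loggers:
        logger.setLevel(logging.ERROR)
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            yield
    finally:
        # Restore the original levels even if the block raised.
        for logger, level in zip(loggers, previous):
            logger.setLevel(level)
```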
- pyearthtools.training.wrapper.lightning.Train
alias of LightingTraining
- pyearthtools.training.wrapper.lightning.train.get_logger(logger, path, **kwargs)
Get a logger.
- Parameters:
logger (str | bool)
path (str)
- pyearthtools.training.wrapper.lightning.train.make_callback(callback, kwargs, **formats)
Make a Lightning callback from `kwargs` formatted with `formats`.
- Parameters:
callback (str)
kwargs (dict[str, Any])
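`make_callback` formats the callback's `kwargs` with `formats` before building it. Below is a hypothetical sketch of just the formatting step, assuming string values (including those nested in dicts, lists and tuples) are filled with `str.format(**formats)`. Constructing the callback from the `callback` name, for example by looking it up in `lightning.pytorch.callbacks`, is left out and is itself an assumption.

```python
from typing import Any


def format_kwargs(kwargs: dict[str, Any], **formats: Any) -> dict[str, Any]:
    """Recursively apply str.format(**formats) to string values in kwargs."""
    def fmt(value: Any) -> Any:
        if isinstance(value, str):
            return value.format(**formats)
        if isinstance(value, dict):
            return {k: fmt(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return type(value)(fmt(v) for v in value)
        return value  # numbers, None, etc. pass through unchanged

    return {key: fmt(value) for key, value in kwargs.items()}


kwargs = {"dirpath": "{path}/checkpoints", "monitor": "valid/loss", "save_top_k": 3}
print(format_kwargs(kwargs, path="/tmp/run1"))
# {'dirpath': '/tmp/run1/checkpoints', 'monitor': 'valid/loss', 'save_top_k': 3}
```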
- class pyearthtools.training.wrapper.lightning.wrapper.LightningWrapper(model, data, path, trainer_kwargs=None, **kwargs)
PyTorch Lightning ModelWrapper.
- For prediction use pyearthtools.training.lightning.Predict
- For training use pyearthtools.training.lightning.Train
Base PyTorch Lightning model wrapper.
- Parameters:
model (L.LightningModule) – Lightning model to use for prediction.
data (dict[str, Pipeline | str | tuple[Pipeline, ...]] | tuple[Pipeline | str, ...] | str | Pipeline | PipelineLightningDataModule) – Pipeline to use to get data. Will be converted into a PipelineLightningDataModule.
path (str | Path) – Root path.
trainer_kwargs (Optional[dict[str, Any]], optional) – Kwargs for L.Trainer. Defaults to None.
- class pyearthtools.training.wrapper.predict.TimeSeriesPredictor(model, reverse_pipeline=None, *, fix_time_dim=True, interval=1, time_dim='time')
Temporal predictions.
Adds `recurrent`, which is expected to be implemented by a subclass.
- Hooks:
prepare_output(prediction) -> prediction: Function executed to prepare model outputs for use as inputs.
Usage:
>>> model = ModelWrapper(MODEL_GOES_HERE, DATA_PIPELINE)
>>> predictor = TimeSeriesPredictionWrapper(model)
>>> predictor.recurrent('2000-01-01T00', steps = 10)
Predict a time series with a `model`.
- Parameters:
model (ModelWrapper) – Model and data source to use.
reverse_pipeline (Pipeline | int | str | None) – Override for the Pipeline to use on the undo operation. If not given, defaults to `model.pipelines`. If a str or int, the value is used to index into `model.pipelines`, which is useful if `model.pipelines` is a dictionary or tuple. Can also be a Pipeline itself. If `reverse_pipeline.has_source()` is True, `reverse_pipeline.undo` is run; otherwise the pipeline is applied with `reverse_pipeline.apply`.
fix_time_dim (bool) – Fix the time dimension after prediction.
interval (int | str | TimeDelta) – Interval of temporal predictions; must be accepted by pyearthtools.data.TimeDelta.
time_dim (str) – Name of the time dimension in undone data.
- fix_time_dim(idx, data, *, offset=1)
The time dimension is usually wrong after rolling out, so this attempts to fix it.
Uses `interval` and `time_dim` from `__init__`.
- Parameters:
idx (Any) – Starting index
data (XR_TYPE) – Data to fix time dimension of
offset (int, optional) – Offset of idx. Defaults to 1.
- Returns:
Data with fixed time dim
- Return type:
(XR_TYPE)
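A plausible reading of `fix_time_dim` is that the rolled-out steps are relabelled so that step i sits at idx + (offset + i) * interval. The stdlib sketch below computes those timestamps under that assumption. `fixed_times` is hypothetical, and the real method works on xarray objects and `pyearthtools.data.TimeDelta` rather than `datetime.timedelta`.

```python
from datetime import datetime, timedelta


def fixed_times(idx: str, n_steps: int, interval: timedelta, offset: int = 1) -> list[datetime]:
    """Timestamps for n_steps predictions, the first one `offset` intervals after idx."""
    start = datetime.fromisoformat(idx)
    return [start + (offset + i) * interval for i in range(n_steps)]


times = fixed_times("2000-01-01T00:00", 3, timedelta(hours=6))
print([t.isoformat() for t in times])
# ['2000-01-01T06:00:00', '2000-01-01T12:00:00', '2000-01-01T18:00:00']
```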
- predict(idx, fake_batch_dim=True, **kwargs)
Run a prediction with `model` using data from `idx`.
- Parameters:
idx (Any) – Index to get initial conditions from.
fake_batch_dim (bool, optional) – Whether to fake the batch dim. Defaults to True.
- Returns:
Prediction data after being run through `reverse` and `after_predict`.
- Return type:
(Any)
- prepare_output(output)
Hook to prepare model output for use as input.
- class pyearthtools.training.wrapper.predict.TimeSeriesAutoRecurrentPredictor(model, reverse_pipeline=None, *, fix_time_dim=True, interval=1, time_dim='time', combine='concat', combine_axis=0)
AutoRecurrent temporal predictions.
Predict a time series with a `model`. `combine` and `combine_axis` can be used to modify how timesteps are combined.
If model predictions have a leading time dim, use `concat`; if the time dim is at the 2nd axis, set `combine_axis = 1`.
If no time dim is included, set `combine` to `stack`.
If data must be reversed before being combined, set `combine = None`. Each step will then be undone, and `xr.combine_by_coords` used.
Warning
If `combine` is set, the pipeline used to undo the predictions must allow a change in the time dimension, i.e. no squish or expand operations on that dim.
- Parameters:
model (ModelWrapper) – Model and data source to use.
reverse_pipeline (Pipeline | int | str | None) – Override for the Pipeline to use on the undo operation. If not given, defaults to `model.pipelines`. If a str or int, the value is used to index into `model.pipelines`, which is useful if `model.pipelines` is a dictionary or tuple. Can also be a Pipeline itself. If `reverse_pipeline.has_source()` is True, `reverse_pipeline.undo` is run; otherwise the pipeline is applied with `reverse_pipeline.apply`.
fix_time_dim (bool) – Fix the time dimension after prediction.
interval (int | str | TimeDelta) – Interval of temporal predictions; must be accepted by pyearthtools.data.TimeDelta.
time_dim (str) – Name of the time dimension in undone data.
combine (Optional[Literal['stack', 'concat']]) – How to combine timesteps: either stack on `combine_axis` or concat. If None, do not combine before the undo operation, and use `xr.combine_by_coords` afterwards. `concat` concatenates on an existing axis, whereas `stack` stacks on a new axis.
combine_axis (int) – If combining, which axis to combine on. The batch dim is removed, so 0 is actually 1 with the batch dim included.
- recurrent(idx, steps, *, fake_batch_dim=True, verbose=False)
Predict autorecurrently.
Requires model `inputs == outputs @ t+1`. Runs for `steps` steps ahead, feeding model outputs back in to predict the next step.
- Parameters:
idx (Any) – Index to get initial conditions at.
steps (int) – Number of steps (model iterations) to roll out for.
fake_batch_dim (bool, optional) – Fake batch dim when getting a sample of data. Defaults to True.
verbose (bool, optional) – Show progress. Defaults to False.
- Returns:
Combined temporal data
- Return type:
(Any)
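The autorecurrent loop (feed each output back in as the next input, then combine the steps) can be sketched with plain lists. This is a hypothetical illustration (`rollout` and the toy models are not pyearthtools API) of the two combine modes described for the class: `stack` adds a new time axis, while `concat` joins outputs that already carry a leading time axis. The `combine=None` path, which undoes each step and uses `xr.combine_by_coords`, is not reproduced.

```python
from typing import Callable, List


def rollout(model: Callable[[list], list], initial: list, steps: int,
            combine: str = "stack") -> list:
    """Run `model` autorecurrently for `steps` steps and combine the outputs."""
    outputs: List[list] = []
    state = initial
    for _ in range(steps):
        state = model(state)  # requires inputs == outputs @ t+1
        outputs.append(state)
    if combine == "stack":
        return outputs  # new leading time axis, one entry per step
    if combine == "concat":
        return [frame for out in outputs for frame in out]  # join along the existing time axis
    raise ValueError(f"Unsupported combine: {combine!r}")


def step(state: list) -> list:
    """Toy model without a time axis: each variable advances by 1."""
    return [v + 1 for v in state]


def step_with_time(state: list) -> list:
    """Toy model whose output has a leading time axis of length 1."""
    return [[v + 1 for v in state[-1]]]


print(rollout(step, [0, 10], 3, combine="stack"))               # [[1, 11], [2, 12], [3, 13]]
print(rollout(step_with_time, [[0, 10]], 3, combine="concat"))  # [[1, 11], [2, 12], [3, 13]]
```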
- class pyearthtools.training.wrapper.predict.TimeSeriesManagedPredictor(model, variable_manager, output_order, reverse_pipeline=None, *, input_order=None, variable_axis=0, take_missing_from_input=False, fix_time_dim=True, interval=1, time_dim='time', combine='concat', combine_axis=1, **extra_pipelines)
AutoRecurrent prediction where output != input.
Uses `Variables` to manage data shape, and can either retrieve missing data from Pipelines or take it from the input. If not `take_missing_from_input`, expects `model.datamodule.pipelines` to be a dictionary, and `variable_manager` to use the same names.
If the `datamodule` returns data rather than a dictionary, `take_missing_from_input` must be True, as the missing data cannot be retrieved otherwise.
Examples
Say a model takes 10 prognostics and 4 forcings, and predicts only the prognostics:
```python
variable_manager = Variables(prognostics = 10, forcings = 4)
# model has a pipeline datamodule as a dictionary with `prognostics` and `forcings`.
predictor = TimeSeriesManagedPredictor(model, variable_manager, output_order = 'P', reverse_pipeline = 'prognostics')
predictor.recurrent('2000-01-01T00', 10)
```
If the model also returns diagnostics that are not given in the inputs, and `reverse_pipeline` is not given and the `pipelines` data is not a dictionary, put the missing data at the end of the order.
- Note:
If diagnostic-type variables are returned, a `reverse_pipeline` referencing an input pipeline is unlikely to work, so it is best to pass in a new pipeline built to undo model outputs.
AutoRecurrent predictions where output != input.
Expects `model.datamodule.pipelines` to be a dictionary, and `variable_manager` to use the same names. Based on `output_order`, finds the missing data needed for a prediction, and queries the `datamodule` for it if `take_missing_from_input` is False; otherwise it is pulled from the input. `combine_axis` is used to identify the number of time steps predicted in one pass of the model.
- Parameters:
model (ModelWrapper) – Model and data source to use.
variable_manager (Variables) – Variable manager, used to extract components from the `output` data, and from the `input` if the `datamodule` is not a dictionary.
output_order (str) – Order of the output for use with `variable_manager`, e.g. `variable_manager.split(model_output, output_order)`. If the model outputs inputs and diagnostics, `output_order` would be `ID`.
input_order (str, Optional) – Override for the order of input data, if incoming data is not a dictionary. If not given and the incoming data is an array, the default order from `variable_manager` is used. Defaults to None.
variable_axis (int, Optional) – Axis of the tensor holding the variables. Used to ensure separation according to `output_order`. Only used if the model returns a tensor. Defaults to 0.
take_missing_from_input (bool) – Whether to take missing data from the input. Defaults to False.
extra_pipelines (Pipeline, optional) – Extra pipelines to use for missing data retrieval instead of the `datamodule` if `take_missing_from_input` is False. Expected to have the same names as `variable_manager` and `datamodule`. Defaults to {}.
reverse_pipeline (Pipeline | str | None)
fix_time_dim (bool)
interval (int | str | TimeDelta)
time_dim (str)
combine (None | Literal['stack'] | Literal['concat'])
combine_axis (int)
See TimeSeriesAutoRecurrentPredictor for documentation of the remaining arguments.
- recurrent(idx, steps, *, fake_batch_dim=True, verbose=False)
Predict autorecurrently; outputs do not have to equal inputs.
Splits the outputs based on `output_order`, finds the missing keys, and gets them from `pipelines` or the input.
Can be used with datamodules that return dictionaries or data.
If `model` returns a dictionary, the key `prediction` is looked up for the predictions to pass to outputs.
- Parameters:
idx (Any) – Index to get initial conditions at.
steps (int) – Number of steps (model iterations) to roll out for.
fake_batch_dim (bool, optional) – Fake batch dim when getting a sample of data. Defaults to True.
verbose (bool, optional) – Show progress. Defaults to False.
- Returns:
Combined temporal data
- Return type:
(Any)
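For `TimeSeriesManagedPredictor`, the key step is rebuilding the next input from a model output whose layout differs from the input. The sketch below is hypothetical (`slices_for`, `next_input`, and `fetch_missing` are illustrative names, not pyearthtools API) and uses lists for channel-first data. Categories present in `output_order` are taken from the output. Missing ones come either from the previous input (`take_missing_from_input=True`) or from a retrieval callback standing in for the pipelines. Output-only categories such as diagnostics are dropped.

```python
from typing import Callable, Optional


def slices_for(order: str, lengths: dict[str, int]) -> dict[str, slice]:
    """Contiguous channel slice for each category letter in `order`."""
    out, start = {}, 0
    for letter in order:
        out[letter] = slice(start, start + lengths[letter])
        start += lengths[letter]
    return out


def next_input(output: list, previous_input: list, output_order: str, input_order: str,
               lengths: dict[str, int], take_missing_from_input: bool = True,
               fetch_missing: Optional[Callable[[str], list]] = None) -> list:
    """Assemble the next model input, laid out in `input_order`, from the latest output."""
    out_parts = {k: output[s] for k, s in slices_for(output_order, lengths).items()}
    in_parts = {k: previous_input[s] for k, s in slices_for(input_order, lengths).items()}
    new_input: list = []
    for letter in input_order:
        if letter in out_parts:
            new_input.extend(out_parts[letter])      # predicted by the model
        elif take_missing_from_input:
            new_input.extend(in_parts[letter])       # reuse from the previous input
        else:
            new_input.extend(fetch_missing(letter))  # e.g. query a pipeline
    return new_input


lengths = {"P": 2, "F": 1, "D": 1}
# Input is prognostics + forcings; the model returns prognostics + diagnostics.
print(next_input([3, 4, 7], [1, 2, 9], "PD", "PF", lengths))  # [3, 4, 9]
print(next_input([3, 4, 7], [1, 2, 9], "PD", "PF", lengths,
                 take_missing_from_input=False, fetch_missing=lambda letter: [42]))  # [3, 4, 42]
```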
- class pyearthtools.training.wrapper.predict.ManualTimeSeriesPredictor(model, reverse_pipeline=None, *, fix_time_dim=True, interval=1, time_dim='time')
Interface for TimeSeries prediction in which the `model` itself handles all of the recurrence.
Predict a time series with a `model`.
- Parameters:
model (ModelWrapper) – Model and data source to use.
reverse_pipeline (Pipeline | int | str | None) – Override for the Pipeline to use on the undo operation. If not given, defaults to `model.pipelines`. If a str or int, the value is used to index into `model.pipelines`, which is useful if `model.pipelines` is a dictionary or tuple. Can also be a Pipeline itself. If `reverse_pipeline.has_source()` is True, `reverse_pipeline.undo` is run; otherwise the pipeline is applied with `reverse_pipeline.apply`.
fix_time_dim (bool) – Fix the time dimension after prediction.
interval (int | str | TimeDelta) – Interval of temporal predictions; must be accepted by pyearthtools.data.TimeDelta.
time_dim (str) – Name of the time dimension in undone data.