Training API Docs

training.dataindex

class pyearthtools.training.dataindex.MLDataIndex(wrapper, *, data_interval, cache=None, prediction_function='predict', prediction_config=None, offsetInterval=False, post_transforms=None, override=False, data_attributes=None, **kwargs)

pyearthtools.training DataIndex

Uses an underlying ML model to generate data to cache.

Set up the ML data index from a defined wrapper.

Info:

This can be used just like an Index (pyearthtools.data.indexes) from pyearthtools.data, so calling or indexing into this object works, and transforms can also be supplied.

Parameters:
  • wrapper (pyearthtoolsTrainer) – pyearthtoolsTrainer to use to retrieve data

  • data_interval (tuple) – Resolution that the wrapper operates at, in TimeDelta form, e.g. (1, ‘day’).

  • cache (str | Path, optional) – Location to cache outputs. If not supplied, outputs are not cached.

  • prediction_function (str, optional) – Name of the function to use for prediction. Defaults to ‘predict’.

  • prediction_config (dict, optional) – Configuration for predictions.

  • offsetInterval (bool, optional) – Whether to offset time by interval. Defaults to False.

  • post_transforms (Transform | TransformCollection | None, optional) – Transforms to apply post generation. Defaults to None.

  • override (bool, optional) – Override any generated data. Defaults to False.

  • data_attributes (str | Path | None, optional) – Path to yaml file specifying attributes to set.

  • **kwargs (dict, optional) – Any keyword arguments to pass to BaseCacheIndex (pyearthtools.data.BaseCacheIndex).

property data

Get the data pipeline.

get(*args, **kwargs)

Retrieve the prediction data for the request.

offset_time(time)

Offset the given time.

Controlled by the init arguments: if offsetInterval is a bool and True, offset by the interval; otherwise, offset by offsetInterval.

Parameters:

time (str | Petdt) – Time to offset

Returns:

Offset time

Return type:

(Petdt)
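The rule that offset_time follows can be sketched in plain Python. This is a hedged illustration, not the library's code: offset_time here is a hypothetical stand-alone function, datetime stands in for Petdt, and treating offsetInterval=False as "no offset" is an assumption.

```python
from datetime import datetime, timedelta
from typing import Union


# Hypothetical stand-in for MLDataIndex.offset_time; datetime replaces Petdt.
def offset_time(
    time: Union[str, datetime],
    interval: timedelta,
    offset_interval: Union[bool, timedelta] = False,
) -> datetime:
    """Offset `time` following the rule described for offsetInterval."""
    if isinstance(time, str):
        time = datetime.fromisoformat(time)
    if isinstance(offset_interval, bool):
        # True -> shift by the data interval; False -> unchanged (assumption)
        return time + interval if offset_interval else time
    # Any non-bool value is treated as the offset itself
    return time + offset_interval


print(offset_time("2000-01-01T00:00", timedelta(days=1), True))
# 2000-01-02 00:00:00
```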

training.manage

class pyearthtools.training.manage.Variables(variables=None, *, order=None, **kwargs)

Variable management class.

Allows specifying the categories of variables in an array, and their ordering.

Provides functions for:
  • Reordering

  • Splitting

  • Adding

  • Removing

  • Extracting

  • Joining

These functions can be run on any incoming data, with a choice of the associated order.

The order is specified as a string made of the capitalised first letter of each category, e.g. 'PFD' for prognostics, forcings and diagnostics.

All arrays are expected to be channel / variable first.

E.g.

>>> Variables(order = 'PFD', prognostics = 10, diagnostics = 5, forcings = 7)
Variables - (PFD)
   prognostics - 10
   forcings - 7
   diagnostics - 5
>>> variables = Variables(order = 'PFD', prognostics = 10, diagnostics = 5, forcings = 7)
>>> variables.split(data)  # dict of categories, data split accordingly
>>> variables.extract(data, 'diagnostics')  # diagnostics extracted from the data

Set up the variable manager.

Parameters:
  • order (Optional[str], optional) –

    Order of categories. Uses the first letter capitalised.

    If not given, is inferred from order of kwargs

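  • variables (Optional[Union[Variables, dict[str, list[str] | int]]]) – Categories to manage, either as an existing Variables or as a mapping of category name to variable names or a variable count.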

add(data: D, incoming: tuple[D], category: tuple[str], order: str | None = None) → D
add(data: D, incoming: D, category: str, order: str | None = None) → D

Add incoming data into the data, in the correct spot.

The result will be reordered into the ‘correct’ order.

See .join for explicit assignment.

Parameters:
  • data (DATA_TYPES) – Data to add to

  • incoming (DATA_TYPES) – Data to add

  • category (str) – Category of data to add

  • order (Optional[str], optional) – Order of data if different from the specified order; should not contain an entry for category. Defaults to None.

Returns:

Merged data in the order as specified in init.

Return type:

(DATA_TYPES)

Examples

>>> import numpy as np
>>> variables = Variables(order = 'PFD', prognostics = 10, diagnostics = 5, forcings = 7)
>>> forcings_missing = np.ones(15)
>>> forcings = np.zeros(7)
>>> data = variables.add(forcings_missing, forcings, 'forcings', order = 'PD')
>>> data.shape
(22,)
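
The reordering that add performs can be illustrated with plain NumPy. The sketch below is hypothetical and is not the library's implementation: add_category, SIZES and INIT_ORDER are illustrative names, and categories are keyed by their order letters rather than full names.

```python
import numpy as np

# Hypothetical category sizes and init order, mirroring the Variables example
SIZES = {"P": 10, "F": 7, "D": 5}
INIT_ORDER = "PFD"


def add_category(data: np.ndarray, incoming: np.ndarray, letter: str, order: str) -> np.ndarray:
    """Insert `incoming` (category `letter`) into `data` (laid out as `order`)
    and return the result in INIT_ORDER, working on the leading channel axis."""
    if letter in order:
        raise ValueError(f"{letter!r} is already present in order {order!r}")
    # Split the existing data into its categories along axis 0
    parts, start = {}, 0
    for cat in order:
        parts[cat] = data[start:start + SIZES[cat]]
        start += SIZES[cat]
    parts[letter] = incoming
    # Re-join every present category in the initial order
    return np.concatenate([parts[c] for c in INIT_ORDER if c in parts], axis=0)


result = add_category(np.ones(15), np.zeros(7), "F", order="PD")
print(result.shape)  # (22,)
```
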
property category_names: list[str]

Get the names of the specified categories; these may not be in order.

check_category(category)

Check whether category is specified.

Parameters:

category (str)

check_order(order)

Check that all elements of order are valid.

Parameters:

order (str)

compare_length(data, order=None, error=False)

Compare the length of data to the expected length.

If error is True, raise an error on mismatch.

Parameters:
  • data (ndarray | Dataset | Tensor | list)

  • order (str | None)

  • error (bool)

Return type:

bool

extract(data: D, category: tuple[str, ...], order: str | None = None) → tuple[D, ...]
extract(data: D, category: str, order: str | None = None) → D

Extract a category from the data.

If order is given, data can contain only a subset of categories, and will extract correctly.

Parameters:
  • data (DATA_TYPES) – Data to extract category from.

  • category (tuple[str, ...] | str) – Category to extract

  • order (Optional[str], optional) – Order if different to specified. Defaults to None.

Returns:

Extracted data

Return type:

(tuple[DATA_TYPES, …] | DATA_TYPES)

join(**kwargs)

Join the data given in kwargs, ordered by the specified order.

Parameters:

kwargs (ndarray | Dataset | Tensor | list)

Return type:

ndarray | Dataset | Tensor | list

length(category)

Get the length of the given category or categories.

Parameters:

category (str | list[str] | tuple[str, ...])

Return type:

int

length_of_order(order)

Get expected length of categories as listed in order.

Parameters:

order (str)

Return type:

int

names_from_order(order=None, reorder=False)

Get names of categories from order

Parameters:
  • order (Optional[str], optional) – Order to get names from. Defaults to None.

  • reorder (bool, optional) – Whether to reorder order to match the order specified at init. Defaults to False.

Returns:

Names of categories in order, or in the ‘correct’ init order if reorder is True.

Return type:

(tuple[str, …])

np_slices(order=None)

Get slices for extracting data.

Parameters:

order (Optional[str], optional) – Incoming order if different from specified. Defaults to None.

Returns:

Category to slice pairs

Return type:

(dict[str, slice])
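How category sizes turn into channel slices can be sketched as follows, assuming channel-first data and fixed integer sizes per category. np_slices here is a hypothetical stand-alone function that mirrors the documented return type (dict[str, slice]), keyed by order letter rather than category name.

```python
from typing import Dict

import numpy as np

# Hypothetical category sizes, keyed by the capitalised first letter
SIZES = {"P": 10, "F": 7, "D": 5}


def np_slices(order: str) -> Dict[str, slice]:
    """Map each category in `order` to its slice along the channel axis."""
    slices, start = {}, 0
    for letter in order:
        slices[letter] = slice(start, start + SIZES[letter])
        start += SIZES[letter]
    return slices


data = np.arange(22)  # channel-first data laid out as 'PFD'
slices = np_slices("PFD")
diagnostics = data[slices["D"]]  # extract one category
print(slices["F"], diagnostics.shape)  # slice(10, 17, None) (5,)
```

Passing a subset order such as "PD" shifts the slices accordingly, which is why extract and remove accept an order argument.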

remove(data, category, order=None)

Remove a category of data.

If order is given, data can contain only a subset of categories, and will return in the given order.

Parameters:
  • data (DATA_TYPES) – Data to remove category from

  • category (str) – Category to remove

  • order (Optional[str], optional) – Order if different to specified. Can be a subset. Defaults to None.

Returns:

Data with category removed. If order is given, it will be maintained.

Return type:

(DATA_TYPES)

reorder(data, order=None)

Reorder incoming data into the originally specified order.

Parameters:
  • data (D)

  • order (str | None)

Return type:

D

split(data, order=None)

Split incoming data into the specified categories

Parameters:
  • data (DATA_TYPES) – Data to split

  • order (Optional[str], optional) – Order of data, required if data is an array. Defaults to None.

Returns:

Data split into categories, keyed by the specified category names.

Return type:

(dict[str, DATA_TYPES])

xr_slices(data, order=None)

Get slices for use in xarray.

Returns lists of variable names.

Requires data in case variables are given as ints.

Parameters:
  • data (Dataset)

  • order (str | None)

Return type:

dict[str, list[str]]

training.data

class pyearthtools.training.data.PipelineDataModule(pipelines, train_split=None, valid_split=None)

Base PipelineDataModule

get_sample can be used to retrieve samples from the pipelines, and fake_batch_dim can be overridden if special batch faking is needed.

train() configures the pipelines to iterate over train_split, and eval() to iterate over valid_split.

Set up Pipelines for use with ML training.

Parameters:
  • pipelines (dict[str, str | Pipeline | tuple[Pipeline, ...]] | tuple[Pipeline | str, ...] | Pipeline | str) – Pipelines for data retrieval. Can be a dictionary, a list/tuple of Pipelines, or a single Pipeline.

  • train_split (Optional[Iterator], optional) – Iterator to use for training. Pipelines are configured to use it by calling .train(). Defaults to None.

  • valid_split (Optional[Iterator], optional) – Iterator to use for validation. Pipelines are configured to use it by calling .eval(). Defaults to None.

check_for_use()

Check if datamodule is ready for use.

eval()

Set Pipelines to iterate over valid_split.

fake_batch_dim(sample)

Add a fake batch dimension to sample.
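
A minimal sketch of batch faking, assuming NumPy arrays and that nested tuples, lists and dicts should each gain a leading batch axis of size 1. This fake_batch_dim is a stand-alone illustration, not the method on PipelineDataModule; subclasses that need different batching would override the method instead.

```python
from typing import Any

import numpy as np


def fake_batch_dim(sample: Any) -> Any:
    """Add a leading batch axis of size 1, recursing into tuples, lists and
    dicts (the recursion is an assumption about nested samples)."""
    if isinstance(sample, dict):
        return {key: fake_batch_dim(value) for key, value in sample.items()}
    if isinstance(sample, (tuple, list)):
        return type(sample)(fake_batch_dim(value) for value in sample)
    return np.expand_dims(sample, axis=0)


inputs, targets = fake_batch_dim((np.zeros((4, 8, 8)), np.zeros((2, 8, 8))))
print(inputs.shape, targets.shape)  # (1, 4, 8, 8) (1, 2, 8, 8)
```
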

classmethod find_shape(obj)

Find shape of obj

get_sample(idx, *, fake_batch_dim=False)

Get a sample from the pipelines.

Parameters:

fake_batch_dim (bool) – Whether to add a fake batch dimension to the sample.

classmethod load(stream, **kwargs)

Load PipelineDataModule config

Parameters:
  • stream (str | Path) – File or dump to load

  • kwargs (Any) – Updates to default values included in the config.

Returns:

Loaded PipelineDataModule

Return type:

(PipelineDataModule)

map_function_to_pipelines(function, **kwargs)

Map a function over Pipelines

Parameters:

function (Callable[[Pipeline], Any])

save(path=None)

Save PipelineDataModule

Parameters:

path (Optional[str | Path], optional) – File to save to. If not given, return the saved string. Defaults to None.

Returns:

If path is None, PipelineDataModule in save form else None.

Return type:

(Union[None, str])
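
The save contract (return a string when path is None, otherwise write to disk and return None) can be sketched as below. save_config is a hypothetical helper, and JSON is only an assumption for illustration; the serialisation format actually used by pyearthtools is not specified here.

```python
import json
from pathlib import Path
from typing import Optional, Union


def save_config(config: dict, path: Optional[Union[str, Path]] = None) -> Optional[str]:
    """Return the serialised config when no path is given, else write it and return None."""
    dumped = json.dumps(config, indent=2)  # JSON is an illustrative assumption
    if path is None:
        return dumped
    Path(path).write_text(dumped)
    return None


text = save_config({"train_split": None, "valid_split": None})
restored = json.loads(text)  # a matching load would parse the same string
```
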

train()

Set Pipelines to iterate over train_split.

pyearthtools.training.data.default()

Default DataModules

  • Basic sampling and batching

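pyearthtools.training.data.save(datamodule, path=None)

Save Pipeline

Parameters:
  • datamodule – DataModule to save.

  • path (Optional[str | Path], optional) – File to save to. If not given, return the saved string. Defaults to None.

Returns:

If path is None, pipeline in save form else None.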

Return type:

(Union[None, str])

pyearthtools.training.data.load(stream, **kwargs)

Load Datamodule config

Parameters:
  • stream (Union[str, Path]) – File or dump to load

  • kwargs (Any) – Updates to default values included in the config.

Returns:

Loaded Pipeline

Return type:

(pyearthtools.pipeline.Pipeline)

training.wrapper

class pyearthtools.training.wrapper.ModelWrapper(model, data)

Base Model Wrapper

Defines the interface for using a model with a datamodule or Pipeline.

Construct the base model wrapper.

model will not be recorded in the initialisation by default; set _record_model to change this behaviour.

Parameters:
  • model (Any) – Model to use.

  • data (dict[str, Pipeline | tuple[Pipeline, ...]] | tuple[Pipeline, ...] | Pipeline | PipelineDataModule) – Data to use. If not a PipelineDataModule, it will be made into a _default_datamodule, which then only provides get_sample.

get_sample(idx, *, fake_batch_dim=False)

Get sample from the datamodule.

Parameters:

fake_batch_dim (bool)

abstractmethod load(*args, **kwargs)

Load model

property pipelines

Get pipelines from the datamodule.

abstractmethod predict(data, *args, **kwargs)

Run a forward pass with the model

Parameters:

data – Data to run prediction with

abstractmethod save(*args, **kwargs)

Save model

property splits

Training and Validation split as configured by the datamodule.
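
Since load, predict and save are abstract, a concrete wrapper must implement all three. The sketch below uses a hypothetical stand-in ABC (WrapperInterface) instead of importing ModelWrapper, so that it runs on its own; CallableWrapper simply treats any callable as the model.

```python
from abc import ABC, abstractmethod
from typing import Any


class WrapperInterface(ABC):
    """Hypothetical stand-in for the abstract part of ModelWrapper."""

    def __init__(self, model: Any, data: Any):
        self.model = model
        self.data = data

    @abstractmethod
    def load(self, *args, **kwargs): ...

    @abstractmethod
    def predict(self, data, *args, **kwargs): ...

    @abstractmethod
    def save(self, *args, **kwargs): ...


class CallableWrapper(WrapperInterface):
    """Wraps any callable; load and save are no-ops in this sketch."""

    def load(self, *args, **kwargs):
        return self

    def predict(self, data, *args, **kwargs):
        # A forward pass is simply calling the wrapped model
        return self.model(data)

    def save(self, *args, **kwargs):
        return None


wrapper = CallableWrapper(lambda x: [v * 2 for v in x], data=None)
print(wrapper.predict([1, 2, 3]))  # [2, 4, 6]
```

Omitting any of the three methods makes the subclass abstract, so instantiating it raises TypeError.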

class pyearthtools.training.wrapper.TrainingWrapper(model, data)

Model wrapper to enable training

Construct the base model wrapper.

model will not be recorded in the initialisation by default; set _record_model to change this behaviour.

Parameters:
  • model (Any) – Model to use.

  • data (dict[str, Pipeline | tuple[Pipeline, ...]] | tuple[Pipeline, ...] | Pipeline | PipelineDataModule) – Data to use. If not a PipelineDataModule, it will be made into a _default_datamodule, which then only provides get_sample.

class pyearthtools.training.wrapper.Predictor(model, reverse_pipeline)

Wrapper to enable prediction

Hooks:

after_predict (prediction) -> prediction: Function executed after data has been reversed from prediction.

Usage:
>>> model = ModelWrapper(MODEL_GOES_HERE, DATA_PIPELINE)
>>> predictor = Predictor(model)
>>> predictor.predict('2000-01-01T00')

Use a model to run a prediction.

Retrieves initial conditions from model.get_sample, so set its Pipeline accordingly.

Parameters:
  • model (ModelWrapper) – Model and Data source to use.

  • reverse_pipeline (Pipeline | int | str | None) –

    Override for Pipeline to use on the undo operation.

    • If not given, will default to using model.pipelines.

    • If str or int, use the value to index into model.pipelines. Useful if model.pipelines is a dictionary or tuple.

    • Or a Pipeline itself can be given.

    • If reverse_pipeline.has_source() is True, run reverse_pipeline.undo; otherwise, apply the pipeline with reverse_pipeline.apply.
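
The reverse_pipeline rules above can be expressed as two small functions. resolve_reverse_pipeline, reverse and FakePipeline are hypothetical names used only for illustration; the only pipeline methods assumed are has_source, undo and apply, as named in the description.

```python
from typing import Any, Union


def resolve_reverse_pipeline(model_pipelines: Any, reverse_pipeline: Union[Any, int, str, None]) -> Any:
    """Pick the pipeline used for the undo step, following the documented rules."""
    if reverse_pipeline is None:
        return model_pipelines  # default: the model's own pipelines
    if isinstance(reverse_pipeline, (int, str)):
        return model_pipelines[reverse_pipeline]  # index into dict/tuple
    return reverse_pipeline  # already a pipeline


def reverse(pipeline: Any, prediction: Any) -> Any:
    """Undo when the pipeline has a source, otherwise apply it."""
    if pipeline.has_source():
        return pipeline.undo(prediction)
    return pipeline.apply(prediction)


class FakePipeline:
    """Hypothetical pipeline exposing has_source/undo/apply."""

    def __init__(self, source: bool):
        self.source = source

    def has_source(self) -> bool:
        return self.source

    def undo(self, x):
        return ("undo", x)

    def apply(self, x):
        return ("apply", x)


pipelines = {"prognostics": FakePipeline(True), "diagnostics": FakePipeline(False)}
print(reverse(resolve_reverse_pipeline(pipelines, "prognostics"), 1))  # ('undo', 1)
```
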

pyearthtools.training.wrapper.lightning.Predict

alias of LightingPrediction

class pyearthtools.training.wrapper.lightning.predict.LoggingContext(change=True)

Quiet Lightning warnings.

Parameters:

change (bool)

pyearthtools.training.wrapper.lightning.Train

alias of LightingTraining

pyearthtools.training.wrapper.lightning.train.get_logger(logger, path, **kwargs)

Get logger

Parameters:
  • logger (str | bool)

  • path (str)

pyearthtools.training.wrapper.lightning.train.make_callback(callback, kwargs, **formats)

Make Lightning callback from kwargs formatted with formats.

Parameters:
  • callback (str)

  • kwargs (dict[str, Any])

class pyearthtools.training.wrapper.lightning.wrapper.LightningWrapper(model, data, path, trainer_kwargs=None, **kwargs)

PyTorch Lightning ModelWrapper

For prediction, use pyearthtools.training.lightning.Predict.

For training, use pyearthtools.training.lightning.Train.

Base PyTorch Lightning model wrapper.

Parameters:
  • model (L.LightningModule) – Lightning Model to use for prediction.

  • data (dict[str, Pipeline | str | tuple[Pipeline, ...]] | tuple[Pipeline | str, ...] | str | Pipeline | PipelineLightningDataModule) – Pipeline to use to get data. Will be converted into a PipelineLightningDataModule.

  • path (str | Path) – Root path

  • trainer_kwargs (Optional[dict[str, Any]], optional) – Kwargs for L.Trainer. Defaults to None.

class pyearthtools.training.wrapper.predict.TimeSeriesPredictor(model, reverse_pipeline=None, *, fix_time_dim=True, interval=1, time_dim='time')

Temporal predictions

Adds recurrent, which subclasses are expected to implement.

Hooks:

prepare_output (prediction) -> prediction: Function executed to prepare model outputs for use as inputs.

Usage:

>>> model = ModelWrapper(MODEL_GOES_HERE, DATA_PIPELINE)
>>> predictor = TimeSeriesAutoRecurrentPredictor(model)
>>> predictor.recurrent('2000-01-01T00', steps = 10)

Predict a time series with a model.

Parameters:
  • model (ModelWrapper) – Model and Data source to use.

  • reverse_pipeline (Pipeline | int | str | None) –

    Override for Pipeline to use on the undo operation.

    • If not given, will default to using model.pipelines.

    • If str or int, use the value to index into model.pipelines. Useful if model.pipelines is a dictionary or tuple.

    • Or a Pipeline itself can be given.

    • If reverse_pipeline.has_source() is True, run reverse_pipeline.undo; otherwise, apply the pipeline with reverse_pipeline.apply.

  • fix_time_dim (bool) – Fix time dimension after prediction.

  • interval (int | str | TimeDelta) – Interval of temporal predictions; must be accepted by pyearthtools.data.TimeDelta.

  • time_dim (str) – Name of time dimension in undone data.

fix_time_dim(idx, data, *, offset=1)

The time dimension is usually wrong after rolling out, so this attempts to fix it.

Uses interval and time_dim from __init__.

Parameters:
  • idx (Any) – Starting index

  • data (XR_TYPE) – Data to fix time dimension of

  • offset (int, optional) – Offset of idx. Defaults to 1.

Returns:

Data with fixed time dim

Return type:

(XR_TYPE)
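
One plausible way to rebuild the time coordinate is to step forward from the starting time by interval, beginning offset intervals after it. The sketch below assumes idx is a timestamp string and interval is a fixed numpy.timedelta64; fixed_times is a hypothetical helper, not the library's method. With xarray, the resulting array could then be attached with Dataset.assign_coords(time=...).

```python
import numpy as np


def fixed_times(start: str, n_steps: int, interval: np.timedelta64, offset: int = 1) -> np.ndarray:
    """Time coordinates for `n_steps` predictions, the first `offset` intervals after `start`."""
    steps = np.arange(offset, offset + n_steps)
    return np.datetime64(start) + steps * interval


times = fixed_times("2000-01-01T00", 3, np.timedelta64(6, "h"))
```
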

predict(idx, fake_batch_dim=True, **kwargs)

Run a prediction with the model, using data from idx.

Parameters:
  • idx (Any) – Index to get initial conditions from

  • fake_batch_dim (bool, optional) – Whether to fake the batch dim. Defaults to True.

Returns:

Prediction data after being run through reverse and after_predict.

Return type:

(Any)

prepare_output(output)

Hook to prepare model outputs for use as inputs.

class pyearthtools.training.wrapper.predict.TimeSeriesAutoRecurrentPredictor(model, reverse_pipeline=None, *, fix_time_dim=True, interval=1, time_dim='time', combine='concat', combine_axis=0)

AutoRecurrent temporal predictions.

Predict a time series with a model.

combine and combine_axis can be used to modify how timesteps are combined.

If model predictions have a leading time dim, use concat; if the time dim is on the second axis, set combine_axis = 1.

If no time dim is included, set combine to stack.

If data must be reversed before being combined, set combine = None. Each prediction will then be undone, and xr.combine_by_coords used to combine them.

Warning

If combine is set, the pipeline used to undo the predictions must allow a change in the time dimension, i.e. no squish or expand operations on that dimension.

Parameters:
  • model (ModelWrapper) – Model and Data source to use.

  • reverse_pipeline (Pipeline | int | str | None) –

    Override for Pipeline to use on the undo operation.

    • If not given, will default to using model.pipelines.

    • If str or int, use the value to index into model.pipelines. Useful if model.pipelines is a dictionary or tuple.

    • Or a Pipeline itself can be given.

    • If reverse_pipeline.has_source() is True, run reverse_pipeline.undo; otherwise, apply the pipeline with reverse_pipeline.apply.

  • fix_time_dim (bool) – Fix time dimension after prediction.

  • interval (int | str | TimeDelta) – Interval of temporal predictions; must be accepted by pyearthtools.data.TimeDelta.

  • time_dim (str) – Name of time dimension in undone data.

  • combine (Optional[Literal['stack', 'concat']]) – How to combine timesteps, either stack on combine_axis or concat. If None, do not combine before undo operation and use xr.combine_by_coords after. concat concatenates on existing axis, whereas stack stacks on new axis.

  • combine_axis (int) – Axis to combine on, if combining. The batch dimension is removed first, so 0 corresponds to axis 1 when the batch dimension is included.
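
The difference between concat and stack can be seen with NumPy on hypothetical per-step outputs that each carry a batch dimension of size 1. Removing the batch dimension first mirrors the combine_axis description; the shapes are illustrative only.

```python
import numpy as np

# Hypothetical per-step model outputs, each (batch=1, time=2, channels=3)
steps = [np.full((1, 2, 3), i) for i in range(4)]

# Remove the batch dim first, as described for combine_axis
unbatched = [s[0] for s in steps]  # each (2, 3)

concat = np.concatenate(unbatched, axis=0)  # extend the existing time axis
stack = np.stack(unbatched, axis=0)  # add a new leading step axis
print(concat.shape, stack.shape)  # (8, 3) (4, 2, 3)
```
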

recurrent(idx, steps, *, fake_batch_dim=True, verbose=False)

Predict autorecurrently.

Requires the model outputs at step t to match the model inputs at step t+1.

Runs for steps iterations ahead, feeding model outputs back in to predict the next step.

Parameters:
  • idx (Any) – Index to get initial conditions at

  • steps (int) – Number of steps (model iterations) to roll out for.

  • fake_batch_dim (bool, optional) – Fake batch dim when getting a sample of data. Defaults to True.

  • verbose (bool, optional) – Show progress. Defaults to False.

Returns:

Combined temporal data

Return type:

(Any)
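
The autorecurrent loop can be sketched independently of the library, assuming a model whose output has exactly the same shape as its input. rollout is a hypothetical function; stacking the steps on a new leading axis corresponds to combine = 'stack'.

```python
from typing import Callable

import numpy as np


def rollout(model: Callable[[np.ndarray], np.ndarray], initial: np.ndarray, steps: int) -> np.ndarray:
    """Feed each output back in as the next input; outputs must match inputs."""
    state, outputs = initial, []
    for _ in range(steps):
        state = model(state)
        if state.shape != initial.shape:
            raise ValueError("model outputs must have the same shape as its inputs")
        outputs.append(state)
    # Leading axis indexes the predicted step
    return np.stack(outputs, axis=0)


trajectory = rollout(lambda x: x + 1.0, np.zeros(3), steps=4)
print(trajectory[:, 0])  # [1. 2. 3. 4.]
```
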

class pyearthtools.training.wrapper.predict.TimeSeriesManagedPredictor(model, variable_manager, output_order, reverse_pipeline=None, *, input_order=None, variable_axis=0, take_missing_from_input=False, fix_time_dim=True, interval=1, time_dim='time', combine='concat', combine_axis=1, **extra_pipelines)

AutoRecurrent prediction where output != input.

Uses Variables to manage data shape, and can either retrieve missing data from Pipelines or take it from the input. If take_missing_from_input is False, expects model.datamodule.pipelines to be a dictionary, and variable_manager to use the same names.

If the datamodule returns data rather than a dictionary, take_missing_from_input must be True, as the missing data cannot be retrieved otherwise.

Examples

Say a model takes 10 prognostics and 4 forcings, and predicts only the prognostics:

>>> variable_manager = Variables(prognostics = 10, forcings = 4)
>>> # model has a pipeline datamodule as a dictionary with 'prognostics' and 'forcings'.
>>> predictor = TimeSeriesManagedPredictor(model, variable_manager, output_order = 'P', reverse_pipeline = 'prognostics')
>>> predictor.recurrent('2000-01-01T00', 10)

If diagnostics are given back by the model, and not given in the inputs.

If reverse_pipeline is not given and the pipeline data is not a dictionary, put missing data at the end of the order.

Note

If diagnostic type variables are returned, it is unlikely that reverse_pipeline referencing an input pipeline will work, so it is best to pass in a new pipeline built to undo model outputs.

AutoRecurrent predictions where output != input.

Expects model.datamodule.pipelines to be a dictionary, and variable_manager to use the same names. Based on output_order, finds the missing data needed for a prediction and queries the datamodule for it if take_missing_from_input is False; otherwise, pulls it from the input.

combine_axis is used to identify the number of time steps predicted in one pass of the model.

Parameters:
  • model (ModelWrapper) – Model and Data source to use.

  • variable_manager (Variables) – Variable manager, used to extract components from output data, and input if datamodule is not a dictionary.

  • output_order (str) – Order of output for use with variable_manager, e.g. variable_manager.split(model_output, output_order). If the model outputs inputs and diagnostics, output_order would be ID.

  • input_order (str, Optional) – Override for the order of input data, if incoming data is not a dictionary. If not given and the incoming data is an array, the default order from variable_manager is used. Defaults to None.

  • variable_axis (int, Optional) – Axis of the variables in the tensor. Used to separate variables according to output_order. Only used if the model returns a tensor. Defaults to 0.

  • take_missing_from_input (bool) – Whether to take missing data from the input. Defaults to False.

  • extra_pipelines (Pipeline, optional) – Extra pipelines to use for missing data retrieval instead of datamodule if take_missing_from_input is False. Expected to have the same names as variable_manager and datamodule. Defaults to {}.

  • reverse_pipeline (Pipeline | str | None)

  • fix_time_dim (bool)

  • interval (int | str | TimeDelta)

  • time_dim (str)

  • combine (None | Literal['stack'] | Literal['concat'])

  • combine_axis (int)

See TimeSeriesAutoRecurrentPredictor for documentation of the remaining arguments.

recurrent(idx, steps, *, fake_batch_dim=True, verbose=False)

Predict autorecurrently.

Outputs do not have to equal inputs.

Splits the outputs based on output_order, finds the missing keys, and gets them from the pipelines or the input.

Can be used with datamodules that return dictionaries or data.

If the model returns a dictionary, the predictions to pass to outputs are taken from the key prediction.

Parameters:
  • idx (Any) – Index to get initial conditions at

  • steps (int) – Number of steps (model iterations) to roll out for.

  • fake_batch_dim (bool, optional) – Fake batch dim when getting a sample of data. Defaults to True.

  • verbose (bool, optional) – Show progress. Defaults to False.

Returns:

Combined temporal data

Return type:

(Any)
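
A sketch of the managed case, where the model predicts only prognostics ('P') but needs prognostics and forcings ('PF') as input. managed_rollout is hypothetical, and get_forcings stands in for retrieving the missing category from a pipeline (the take_missing_from_input = False path); variables are assumed to lie on axis 0.

```python
from typing import Callable

import numpy as np


def managed_rollout(model: Callable[[np.ndarray], np.ndarray],
                    initial: np.ndarray,
                    get_forcings: Callable[[int], np.ndarray],
                    steps: int) -> np.ndarray:
    """Roll out a model whose output ('P') is a subset of its input ('PF')."""
    inputs, outputs = initial, []
    for step in range(1, steps + 1):
        prognostics = model(inputs)  # output_order == 'P'
        forcings = get_forcings(step)  # missing category, fetched externally
        inputs = np.concatenate([prognostics, forcings], axis=0)  # rebuild 'PF'
        outputs.append(prognostics)
    return np.stack(outputs, axis=0)


# Hypothetical model: 2 prognostics + 1 forcing in, 2 prognostics out
result = managed_rollout(lambda x: x[:2] + x[2], np.array([0.0, 0.0, 1.0]),
                         lambda step: np.array([1.0]), steps=3)
```
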

class pyearthtools.training.wrapper.predict.ManualTimeSeriesPredictor(model, reverse_pipeline=None, *, fix_time_dim=True, interval=1, time_dim='time')

Interface for TimeSeries prediction in which the model itself handles all of the recurrence.

Predict a time series with a model.

Parameters:
  • model (ModelWrapper) – Model and Data source to use.

  • reverse_pipeline (Pipeline | int | str | None) –

    Override for Pipeline to use on the undo operation.

    • If not given, will default to using model.pipelines.

    • If str or int, use the value to index into model.pipelines. Useful if model.pipelines is a dictionary or tuple.

    • Or a Pipeline itself can be given.

    • If reverse_pipeline.has_source() is True, run reverse_pipeline.undo; otherwise, apply the pipeline with reverse_pipeline.apply.

  • fix_time_dim (bool) – Fix time dimension after prediction.

  • interval (int | str | TimeDelta) – Interval of temporal predictions; must be accepted by pyearthtools.data.TimeDelta.

  • time_dim (str) – Name of time dimension in undone data.