pytorch_lightning_model

Provides a model that predicts the next timesteps using a PyTorch Lightning architecture.

class ArchitectureParams(units: int = 32, activation: str = 'relu')

Bases: object

Defines the parameters for the architecture.

class PytorchLightningModel(time_series_params: TimeSeriesConfig, model_params: PytorchLightningModelConfig)

Bases: Model

Defines a PyTorch Lightning model that predicts the next timesteps.

Initializes the model.

Parameters:
  • time_series_params – configuration for the time series that influences the training and architecture of the model.

  • model_params – configuration for the model.

Raises:

TypeError – if input_length or output_length is not an integer.
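The `TypeError` above can be pictured with a minimal sketch. It assumes that `TimeSeriesConfig` exposes `input_length` and `output_length` attributes; the stand-in dataclass and the `check_lengths` helper are hypothetical and not part of the documented API.

```python
from dataclasses import dataclass


@dataclass
class TimeSeriesConfig:
    """Hypothetical stand-in for the real TimeSeriesConfig."""

    input_length: int
    output_length: int


def check_lengths(config: TimeSeriesConfig) -> None:
    """Raise TypeError if either length is not an integer (sketch only)."""
    for field_name in ("input_length", "output_length"):
        value = getattr(config, field_name)
        # bool is a subclass of int; this sketch rejects it explicitly.
        if not isinstance(value, int) or isinstance(value, bool):
            raise TypeError(
                f"{field_name} must be an integer, got {type(value).__name__}"
            )


check_lengths(TimeSeriesConfig(input_length=10, output_length=2))  # passes silently
```

Dataclasses do not enforce their annotations at runtime, so `TimeSeriesConfig(input_length=10.0, output_length=2)` is accepted by the constructor and only rejected by the explicit check.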

abstract get_model(time_series_params: TimeSeriesConfig, model_params: PytorchLightningModelConfig) → LightningModule

Returns the model.

Parameters:
  • time_series_params – configuration for the time series that influences the training and architecture of the model.

  • model_params – configuration for the model.
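Because `get_model` is abstract, a concrete model is expected to subclass `PytorchLightningModel` and return its own `LightningModule`. The stdlib-only sketch below shows that pattern with `abc`, assuming the constructor calls `get_model`. `BaseModel`, `TinyModel`, and the placeholder return value are hypothetical; a real subclass would build and return a `LightningModule`.

```python
from abc import ABC, abstractmethod
from typing import Any


class BaseModel(ABC):
    """Hypothetical stand-in for PytorchLightningModel."""

    def __init__(self, time_series_params: Any, model_params: Any) -> None:
        # Assumption: network construction is delegated to the subclass hook.
        self.model = self.get_model(time_series_params, model_params)

    @abstractmethod
    def get_model(self, time_series_params: Any, model_params: Any) -> Any:
        """Return the network; concrete subclasses must override this."""


class TinyModel(BaseModel):
    def get_model(self, time_series_params: Any, model_params: Any) -> Any:
        # A real implementation would return a LightningModule here.
        return {"kind": "placeholder network"}


model = TinyModel(time_series_params=None, model_params=None)
```

Instantiating `BaseModel` directly raises `TypeError`, because Python refuses to create instances of a class with unimplemented abstract methods.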

property name: str

Returns the model's name.

Returns:

The model's name.
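A minimal sketch of how the `name` property might be backed by the configuration, assuming it simply returns `PytorchLightningModelConfig.name` (default `'Pytorch Lightning Model'`). `ModelConfigStub` and `NamedModel` are hypothetical stand-ins.

```python
from dataclasses import dataclass


@dataclass
class ModelConfigStub:
    """Hypothetical stand-in holding only the documented `name` field."""

    name: str = "Pytorch Lightning Model"


class NamedModel:
    def __init__(self, model_params: ModelConfigStub) -> None:
        self.model_params = model_params

    @property
    def name(self) -> str:
        # Read-only property: the name is taken from the configuration.
        return self.model_params.name


print(NamedModel(ModelConfigStub()).name)  # Pytorch Lightning Model
```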

predict(data: ndarray[Any, dtype[float64]]) → ndarray[Any, dtype[float64]]

Predicts the next timesteps for every row (time series).

Parameters:

data – np.array in which each entry along the first axis is one time series.

Returns:

np.array in which each entry along the first axis is the prediction for the corresponding input time series.
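The input/output contract of `predict` can be illustrated with a stand-in predictor. This sketch assumes inputs of shape `(n_series, input_length, n_features)` and outputs of shape `(n_series, output_length, n_features)`. The reference does not state this layout. `persistence_predict` is a hypothetical naive baseline that repeats the last observed timestep; it is not the trained model.

```python
import numpy as np


def persistence_predict(data: np.ndarray, output_length: int) -> np.ndarray:
    """Hypothetical stand-in for `predict`: repeat the last observed timestep.

    Assumes `data` has shape (n_series, input_length, n_features) and returns
    an array of shape (n_series, output_length, n_features).
    """
    last = data[:, -1:, :]  # shape (n_series, 1, n_features)
    return np.repeat(last, output_length, axis=1)


batch = np.arange(2 * 4 * 3, dtype=np.float64).reshape(2, 4, 3)
prediction = persistence_predict(batch, output_length=2)
print(prediction.shape)  # (2, 2, 3)
```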

train(train: list[numpy.ndarray[Any, numpy.dtype[numpy.float64]]]) → None

Trains the model with the given data.

Parameters:

train – training data, given as a list of NumPy arrays.
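The reference does not say how the list of training series is turned into supervised examples. A common approach for next-timestep models is sliding windows: `input_length` steps as input and the following `output_length` steps as target. The helper below sketches that technique. `make_windows` is hypothetical, and it assumes each series has shape `(timesteps, n_features)`.

```python
import numpy as np


def make_windows(
    train: list[np.ndarray], input_length: int, output_length: int
) -> tuple[np.ndarray, np.ndarray]:
    """Cut every series into (input, target) windows; hypothetical helper.

    Each series is assumed to have shape (timesteps, n_features). Raises
    ValueError (from np.stack) if no series is long enough for one window.
    """
    inputs, targets = [], []
    window = input_length + output_length
    for series in train:
        # Slide one step at a time; too-short series contribute no windows.
        for start in range(len(series) - window + 1):
            inputs.append(series[start : start + input_length])
            targets.append(series[start + input_length : start + window])
    return np.stack(inputs), np.stack(targets)


series_a = np.arange(6, dtype=np.float64).reshape(6, 1)
series_b = np.arange(5, dtype=np.float64).reshape(5, 1)
X, y = make_windows([series_a, series_b], input_length=3, output_length=1)
print(X.shape, y.shape)  # (5, 3, 1) (5, 1, 1)
```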

validate_prediction_input(data: ndarray[Any, dtype[float64]]) → None

Validates the input of the predict function.

Parameters:

data – a single array containing the input data for which the output will be predicted.

Raises:

ValueError – if data has incorrect shape (row length does not equal )
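A shape check of this kind might look like the sketch below. The expected row length is left as a parameter, `expected_length` (a hypothetical name), because the reference does not say what it is compared against. The sketch also assumes that the "row length" is the size of axis 1.

```python
import numpy as np


def validate_prediction_input(data: np.ndarray, expected_length: int) -> None:
    """Hypothetical shape check: axis 1 must have `expected_length` entries."""
    if data.ndim < 2:
        raise ValueError(f"expected at least 2 dimensions, got {data.ndim}")
    if data.shape[1] != expected_length:
        raise ValueError(
            f"row length {data.shape[1]} does not equal {expected_length}"
        )


validate_prediction_input(np.zeros((2, 4, 3)), expected_length=4)  # passes
```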

class PytorchLightningModelConfig(name: str = 'Pytorch Lightning Model', architecture_params: simba_ml.prediction.time_series.models.pytorch_lightning.pytorch_lightning_model.ArchitectureParams = <factory>, training_params: simba_ml.prediction.time_series.models.pytorch_lightning.pytorch_lightning_model.TrainingParams = <factory>, normalize: bool = True)

Bases: ModelConfig

Defines the configuration for the PytorchLightningModel.
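In the signature above, `<factory>` is how Python renders a dataclass field declared with `default_factory`. This strongly suggests that each new config builds its own `ArchitectureParams` and `TrainingParams` rather than sharing one mutable default. The sketch below re-declares the three classes as plain dataclasses with the documented fields and defaults. These are stand-ins, not the library classes, and the undocumented `ModelConfig` base is omitted.

```python
from dataclasses import dataclass, field


@dataclass
class ArchitectureParams:
    units: int = 32
    activation: str = "relu"


@dataclass
class TrainingParams:
    epochs: int = 10
    patience: int = 5
    batch_size: int = 32
    validation_split: float = 0.2
    verbose: int = 0


@dataclass
class PytorchLightningModelConfig:
    name: str = "Pytorch Lightning Model"
    # default_factory gives every config its own nested parameter objects.
    architecture_params: ArchitectureParams = field(default_factory=ArchitectureParams)
    training_params: TrainingParams = field(default_factory=TrainingParams)
    normalize: bool = True


config = PytorchLightningModelConfig(
    architecture_params=ArchitectureParams(units=64),
    training_params=TrainingParams(epochs=50, patience=10),
)
print(config.architecture_params.units, config.training_params.epochs)  # 64 50
```

A factory is also required here: since Python 3.11, dataclasses reject unhashable instances such as `ArchitectureParams()` as plain field defaults.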

class TrainingParams(epochs: int = 10, patience: int = 5, batch_size: int = 32, validation_split: float = 0.2, verbose: int = 0)

Bases: object

Defines the parameters for the training.
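The field names suggest their usual meanings, although the reference does not spell them out. `validation_split` would hold out a fraction of the training data for validation. `patience` would be the number of epochs without validation improvement tolerated before early stopping. The hypothetical helpers below sketch both ideas. Holding out the *last* fraction is an assumption of this sketch.

```python
import math
from typing import Sequence, TypeVar

T = TypeVar("T")


def split_train_validation(
    items: Sequence[T], validation_split: float
) -> tuple[list[T], list[T]]:
    """Hold out the last `validation_split` fraction of items (hypothetical)."""
    n_validation = int(len(items) * validation_split)
    cut = len(items) - n_validation
    return list(items[:cut]), list(items[cut:])


def stopping_epoch(validation_losses: Sequence[float], patience: int) -> int:
    """Return how many epochs run before patience-based early stopping."""
    best = math.inf
    epochs_without_improvement = 0
    for epoch, loss in enumerate(validation_losses, start=1):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(validation_losses)


train_part, val_part = split_train_validation(list(range(10)), validation_split=0.2)
```

With `patience=2`, the losses `[1.0, 0.8, 0.9, 0.85, 0.95, 0.7]` stop training after epoch 4. That is the second consecutive epoch without improving on 0.8, so the later drop to 0.7 is never reached.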