pytorch_lightning_model
Provides a model that predicts the next timesteps with a PyTorch Lightning architecture.
- class ArchitectureParams(units: int = 32, activation: str = 'relu')
Bases: object
Defines the parameters for the architecture.
- class PytorchLightningModel(time_series_params: TimeSeriesConfig, model_params: PytorchLightningModelConfig)
Bases: Model
Defines a PyTorch Lightning model to predict the next timesteps.
Initializes the model.
- Parameters:
time_series_params – configuration for the time series that influences the training and architecture of the model.
model_params – configuration for the model.
- Raises:
TypeError – if input_length or output_length is not an integer.
- abstract get_model(time_series_params: TimeSeriesConfig, model_params: PytorchLightningModelConfig) → LightningModule
Returns the model.
- Parameters:
time_series_params – configuration for the time series that influences the training and architecture of the model.
model_params – configuration for the model.
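Because get_model is abstract, each concrete subclass decides which network is built. The following is a minimal pure-Python sketch of that pattern using the standard `abc` module; `SketchModel` and `ConstantSketchModel` are hypothetical names, the sketch assumes the network is built in the constructor (not confirmed by this reference), and it returns a plain dict where the real class would return a LightningModule.

```python
from abc import ABC, abstractmethod
from typing import Any


class SketchModel(ABC):
    """Hypothetical stand-in for a base class with an abstract factory method."""

    def __init__(self, time_series_params: Any, model_params: Any) -> None:
        # Assumption: the concrete network is built once, from whatever the subclass returns.
        self.model = self.get_model(time_series_params, model_params)

    @abstractmethod
    def get_model(self, time_series_params: Any, model_params: Any) -> Any:
        """Subclasses return the network (a LightningModule in the documented class)."""


class ConstantSketchModel(SketchModel):
    """Hypothetical subclass; returns a dict instead of a real LightningModule."""

    def get_model(self, time_series_params: Any, model_params: Any) -> Any:
        return {"series": time_series_params, "config": model_params}
```

Instantiating `SketchModel` directly raises `TypeError`, which is how the `abc` module enforces that every subclass supplies its own `get_model`.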
- property name: str
Returns the model's name.
- Returns:
The model's name.
- predict(data: ndarray[Any, dtype[float64]]) → ndarray[Any, dtype[float64]]
Predicts the next timesteps for every row (time series).
- Parameters:
data – np.array in which each entry along the first axis is one time series.
- Returns:
np.array in which each entry along the first axis holds the predicted next timesteps of the corresponding input time series.
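To make the row-wise contract of predict concrete, here is a hypothetical naive baseline (not the Lightning model) that repeats each series' last observed timestep. It assumes a 3-D layout of (n_series, input_length, n_features) and an `output_length` argument; both the layout and the function name `predict_last_value` are illustrative assumptions.

```python
import numpy as np
import numpy.typing as npt


def predict_last_value(
    data: npt.NDArray[np.float64], output_length: int
) -> npt.NDArray[np.float64]:
    """Hypothetical baseline: repeat each series' last observed timestep.

    Assumes data has shape (n_series, input_length, n_features) and
    returns shape (n_series, output_length, n_features).
    """
    last = data[:, -1:, :]  # keep the time axis: shape (n_series, 1, n_features)
    return np.repeat(last, output_length, axis=1)
```

Each input series yields exactly one output series, so the first dimension of the result always matches the first dimension of data.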
- train(train: list[numpy.ndarray[Any, numpy.dtype[numpy.float64]]]) → None
Trains the model with the given data.
- Parameters:
train – training data.
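TrainingParams includes `validation_split: float = 0.2`. This reference does not say whether train holds out whole series or individual timesteps; the sketch below assumes whole series from the input list are held out. The helper `split_train_validation` and its `seed` argument are hypothetical.

```python
import random
from typing import Sequence, TypeVar

T = TypeVar("T")


def split_train_validation(
    series: Sequence[T], validation_split: float = 0.2, seed: int = 0
) -> tuple[list[T], list[T]]:
    """Hypothetical helper: hold out a fraction of the time series for validation."""
    if not 0.0 <= validation_split < 1.0:
        raise ValueError("validation_split must be in [0, 1)")
    indices = list(range(len(series)))
    random.Random(seed).shuffle(indices)  # deterministic shuffle for reproducibility
    n_val = int(len(series) * validation_split)
    val_idx = set(indices[:n_val])
    # Preserve the original order within each subset.
    train = [s for i, s in enumerate(series) if i not in val_idx]
    validation = [s for i, s in enumerate(series) if i in val_idx]
    return train, validation
```

With ten series and the default split of 0.2, eight series go to training and two to validation.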
- validate_prediction_input(data: ndarray[Any, dtype[float64]]) → None
Validates the input of the predict function.
- Parameters:
data – the input array for which the next timesteps will be predicted.
- Raises:
ValueError – if data has an incorrect shape (row length does not equal )
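The expected row length is cut off in the description above. The sketch below is one hypothetical way to implement such a check, assuming (not confirmed here) that each series must have exactly `input_length` timesteps in a 3-D array of shape (n_series, input_length, n_features).

```python
import numpy as np
import numpy.typing as npt


def validate_prediction_input(data: npt.NDArray[np.float64], input_length: int) -> None:
    """Hypothetical check in the spirit of the documented ValueError.

    Assumption: data.shape must be (n_series, input_length, n_features).
    """
    if data.ndim != 3:
        raise ValueError(f"expected a 3-D array, got {data.ndim} dimension(s)")
    if data.shape[1] != input_length:
        raise ValueError(
            f"row length {data.shape[1]} does not equal input_length {input_length}"
        )
```

Checking the shape before running the network turns a confusing tensor-size error deep inside the model into an explicit message at the call site.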
- class PytorchLightningModelConfig(name: str = 'Pytorch Lightning Model', architecture_params: simba_ml.prediction.time_series.models.pytorch_lightning.pytorch_lightning_model.ArchitectureParams = <factory>, training_params: simba_ml.prediction.time_series.models.pytorch_lightning.pytorch_lightning_model.TrainingParams = <factory>, normalize: bool = True)
Bases: ModelConfig
Defines the configuration for the PytorchLightningModel.
- class TrainingParams(epochs: int = 10, patience: int = 5, batch_size: int = 32, validation_split: float = 0.2, verbose: int = 0)
Bases: object
Defines the parameters for training.
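The `<factory>` defaults in the PytorchLightningModelConfig signature are how Python's `dataclasses` module displays fields declared with `default_factory`. The following re-creation of the three documented signatures with plain dataclasses is an illustrative sketch, not the library's source; it omits the ModelConfig base class.

```python
from dataclasses import dataclass, field


@dataclass
class ArchitectureParams:
    units: int = 32
    activation: str = "relu"


@dataclass
class TrainingParams:
    epochs: int = 10
    patience: int = 5
    batch_size: int = 32
    validation_split: float = 0.2
    verbose: int = 0


@dataclass
class PytorchLightningModelConfig:
    name: str = "Pytorch Lightning Model"
    # default_factory gives every config its own nested object; the signature shows it as <factory>.
    architecture_params: ArchitectureParams = field(default_factory=ArchitectureParams)
    training_params: TrainingParams = field(default_factory=TrainingParams)
    normalize: bool = True
```

For example, `PytorchLightningModelConfig(training_params=TrainingParams(epochs=50))` overrides only the epoch count and keeps the remaining defaults. Using a factory instead of a shared default instance prevents one config's nested parameters from being mutated through another.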