last_value_predictor
Provides a model that predicts the last given input value.
- class LastValuePredictor(time_series_params: TimeSeriesConfig, model_params: ModelConfig)
Bases: Model
Defines a model that predicts the previous (last observed) value.
Initializes the model.
- Parameters:
time_series_params – Time-series parameters that affect the training and architecture of the model.
model_params – Configuration for the model.
- Raises:
TypeError – if input_length or output_length is not an integer.
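The integer check described under Raises can be illustrated with a minimal sketch. This is not the library's code: `HypotheticalTimeSeriesConfig` and `check_lengths` are hypothetical names, and the sketch assumes `input_length` and `output_length` are plain attributes of the time-series configuration.

```python
from dataclasses import dataclass


@dataclass
class HypotheticalTimeSeriesConfig:
    """Stand-in for TimeSeriesConfig holding only the two checked fields (assumption)."""

    input_length: object = 1
    output_length: object = 1


def check_lengths(config: HypotheticalTimeSeriesConfig) -> None:
    """Raise TypeError unless both lengths are integers (hypothetical helper)."""
    for field in ("input_length", "output_length"):
        value = getattr(config, field)
        if not isinstance(value, int):
            raise TypeError(f"{field} must be an integer, got {type(value).__name__}")


check_lengths(HypotheticalTimeSeriesConfig(input_length=2, output_length=1))  # passes
try:
    check_lengths(HypotheticalTimeSeriesConfig(input_length=2.5))
except TypeError as err:
    print(err)  # input_length must be an integer, got float
```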
- property name: str
Returns the model's name.
- Returns:
The model's name.
- predict(data: ndarray[Any, dtype[float64]]) → ndarray[Any, dtype[float64]]
Predicts the next time steps for every time series in data.
- Parameters:
data – 3-dimensional NumPy array. The first dimension indexes the time series, the second dimension the time steps of a time series, and the third dimension the attributes at a single time step.
- Returns:
A 3-dimensional NumPy array containing the predicted values.
Example
>>> import numpy as np
>>> from simba_ml.prediction.time_series.models import last_value_predictor
>>> from simba_ml.prediction.time_series.config import time_series_config
>>> train = np.array([[[1,2], [1,2]], [[2,5], [2,6]], [[10, 11], [12,12]]])
>>> train.shape
(3, 2, 2)
>>> model_config = last_value_predictor.LastValuePredictorConfig()
>>> ts_config = time_series_config.TimeSeriesConfig()
>>> ts_config.input_length = 2
>>> model = last_value_predictor.LastValuePredictor(ts_config, model_config)
>>> model.train(train=train)
>>> test_input = np.array([[[10, 10], [20, 20]], [[15, 15], [15, 16]]])
>>> print(test_input)
[[[10 10]
  [20 20]]
<BLANKLINE>
 [[15 15]
  [15 16]]]
>>> print(model.predict(test_input))
[[[20 20]]
<BLANKLINE>
 [[15 16]]]
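The prediction in the example above amounts to taking the last time step of each series. The following is a minimal NumPy sketch of that technique, not the `simba_ml` implementation: `predict_last_value` and its `output_length` argument are hypothetical, and repeating the last value to fill a horizon longer than one step is an assumption about how multi-step outputs would be produced.

```python
import numpy as np
import numpy.typing as npt


def predict_last_value(
    data: npt.NDArray[np.float64], output_length: int = 1
) -> npt.NDArray[np.float64]:
    """Repeat the last time step of every series output_length times (hypothetical).

    data has shape (n_series, n_timesteps, n_attributes); the result has
    shape (n_series, output_length, n_attributes).
    """
    if data.ndim != 3:
        raise ValueError(f"expected a 3-dimensional array, got {data.ndim} dimensions")
    last = data[:, -1:, :]  # slice keeps the time axis: (n_series, 1, n_attributes)
    return np.repeat(last, output_length, axis=1)


test_input = np.array([[[10, 10], [20, 20]], [[15, 15], [15, 16]]])
print(predict_last_value(test_input))  # same values as model.predict in the example
```

Slicing with `-1:` rather than indexing with `-1` keeps the time dimension, so the output stays 3-dimensional like the documented return value.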
- train(train: list[numpy.ndarray[Any, numpy.dtype[numpy.float64]]]) → None
Trains the model with the given data.
- Parameters:
train – data that can be used for training.
- validate_prediction_input(data: ndarray[Any, dtype[float64]]) → None
Validates the input of the predict function.
- Parameters:
data – a 3-dimensional NumPy array containing the input data for which the output will be predicted.
- Raises:
ValueError – if data has an incorrect shape (row length does not equal )
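A shape check of this kind can be sketched as below. This hypothetical `check_prediction_input` only enforces the 3-dimensional layout documented for predict plus a non-empty time axis; it deliberately does not reproduce the library's exact length condition, which is not stated here.

```python
import numpy as np
import numpy.typing as npt


def check_prediction_input(data: npt.NDArray[np.float64]) -> None:
    """Hypothetical validator: checks only the documented layout of predict's input."""
    if data.ndim != 3:
        raise ValueError(
            "data must be 3-dimensional (series, time steps, attributes), "
            f"got shape {data.shape}"
        )
    if data.shape[1] == 0:
        # A last-value prediction needs at least one time step per series.
        raise ValueError("data must contain at least one time step per series")


check_prediction_input(np.zeros((2, 2, 2)))  # passes silently
```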
- class LastValuePredictorConfig(name: str = 'Last Value Predictor')
Bases: ModelConfig
Defines the configuration for the LastValuePredictor.
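Judging from its signature, the configuration carries little more than a display name with a default. A minimal sketch under that assumption follows; `ModelConfigSketch` and `LastValuePredictorConfigSketch` are hypothetical dataclasses, not the `simba_ml` classes, and whether the real ModelConfig is a dataclass is not stated here.

```python
from dataclasses import dataclass


@dataclass
class ModelConfigSketch:
    """Hypothetical stand-in for ModelConfig with only a name field."""

    name: str


@dataclass
class LastValuePredictorConfigSketch(ModelConfigSketch):
    """Redeclares name with the documented default; no other fields are assumed."""

    name: str = "Last Value Predictor"


config = LastValuePredictorConfigSketch()
print(config.name)  # Last Value Predictor
custom = LastValuePredictorConfigSketch(name="Naive baseline")
```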