average_predictor
Provides a model that predicts the average of the training data.
- class AveragePredictor(time_series_params: TimeSeriesConfig, model_params: ModelConfig)
Bases: Model
Defines a model that predicts the average of the training data.
Initializes the AveragePredictor.
- Parameters:
time_series_params – parameters of the time series that influence the training and architecture of the model.
model_params – configuration for the model.
- property name: str
Returns the model's name.
- Returns:
The model's name.
- predict(data: ndarray[Any, dtype[float64]]) → ndarray[Any, dtype[float64]]
Predicts the next time steps for every time series (row) in data.
- Parameters:
data – 3-dimensional NumPy array. The first dimension contains the time series, the second contains the time steps of a time series, and the third contains the attributes at a single time step.
- Returns:
A 3-dimensional NumPy array containing the predicted values.
Example
>>> import numpy as np
>>> from simba_ml.prediction.time_series.models import average_predictor
>>> from simba_ml.prediction.time_series.config import time_series_config
>>> train = np.array([[[1,2], [1,2]], [[2,5], [2,6]], [[10, 11], [12,12]]])
>>> train.shape
(3, 2, 2)
>>> model_config = average_predictor.AveragePredictorConfig()
>>> ts_config = time_series_config.TimeSeriesConfig()
>>> ts_config.input_length = 2
>>> model = average_predictor.AveragePredictor(ts_config, model_config)
>>> model.train(train=train)
>>> model.avg
5.5
>>> test_input = np.array([[[10, 10], [20, 20]], [[15, 15], [15, 16]]])
>>> print(test_input)
[[[10 10]
  [20 20]]

 [[15 15]
  [15 16]]]
>>> print(model.predict(test_input))
[[[5.5 5.5]]

 [[5.5 5.5]]]
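The shape contract of predict can be illustrated with a small standalone function. This is a sketch, not the library's implementation: predict_average and its output_length argument are hypothetical, and the default of one predicted time step per series is an assumption chosen to match the example above.

```python
import numpy as np
import numpy.typing as npt


def predict_average(
    data: npt.NDArray[np.float64], avg: float, output_length: int = 1
) -> npt.NDArray[np.float64]:
    """Return an array filled with avg, one forecast block per time series.

    data has shape (n_series, n_timesteps, n_attributes); the result has shape
    (n_series, output_length, n_attributes). output_length is a hypothetical
    parameter, not part of the documented API.
    """
    if data.ndim != 3:
        raise ValueError(f"expected a 3-dimensional array, got {data.ndim} dimensions")
    n_series, _, n_attributes = data.shape
    # The input values are ignored; only the number of series and attributes matter.
    return np.full((n_series, output_length, n_attributes), avg, dtype=np.float64)


test_input = np.array([[[10, 10], [20, 20]], [[15, 15], [15, 16]]], dtype=np.float64)
print(predict_average(test_input, avg=5.5))
```

With avg=5.5 this prints the same (2, 1, 2) array as the doctest above. Because the prediction never looks at the input values, the model is mainly useful as a baseline for comparing other predictors.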
- train(train: list[numpy.ndarray[Any, numpy.dtype[numpy.float64]]]) → None
Trains the model with the given data.
- Parameters:
train – the data used for training.
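The example for predict shows model.avg equal to 5.5, which is the mean of all twelve training values (66 / 12). A minimal sketch of that computation follows; train_average is a hypothetical helper, and treating the average as the mean over every value of every array is an assumption consistent with the example rather than a confirmed detail of the implementation.

```python
import numpy as np
import numpy.typing as npt


def train_average(train: list[npt.NDArray[np.float64]]) -> float:
    """Return the mean of every value in every training array (hypothetical helper)."""
    if len(train) == 0:
        raise ValueError("train must contain at least one array")
    # Flatten each array so arrays of different shapes can be combined.
    values = np.concatenate([np.asarray(a, dtype=np.float64).ravel() for a in train])
    return float(values.mean())


train = [
    np.array([[1, 2], [1, 2]]),
    np.array([[2, 5], [2, 6]]),
    np.array([[10, 11], [12, 12]]),
]
print(train_average(train))  # 5.5
```

Pooling all values weights each value equally, so longer time series contribute more to the average than shorter ones; whether the real model handles series of unequal length the same way is not stated here. Iterating over a 3-dimensional array also yields 2-dimensional slices, so the helper accepts the array from the example directly.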
- validate_prediction_input(data: ndarray[Any, dtype[float64]]) → None
Validates the input of the predict function.
- Parameters:
data – a single array containing the input data for which the output will be predicted.
- Raises:
ValueError – if data has incorrect shape (row length does not equal )
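A shape check of this kind can be sketched as follows. The expected length is passed in as a hypothetical expected_length argument because the documentation does not say which value the real method compares the row length against; the 3-dimensional check follows the layout described for predict.

```python
import numpy as np
import numpy.typing as npt


def validate_prediction_input(
    data: npt.NDArray[np.float64], expected_length: int
) -> None:
    """Raise ValueError unless data is 3-D with expected_length time steps per series.

    expected_length is a hypothetical parameter standing in for whatever
    length the real check compares against.
    """
    if data.ndim != 3:
        raise ValueError(f"data must be 3-dimensional, got shape {data.shape}")
    if data.shape[1] != expected_length:
        raise ValueError(
            f"row length {data.shape[1]} does not equal expected length {expected_length}"
        )


# Passes silently: 2 series, 2 time steps, 2 attributes.
validate_prediction_input(np.zeros((2, 2, 2)), expected_length=2)
```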
- class AveragePredictorConfig(name: str = 'Average Predictor')
Bases: ModelConfig
Defines the configuration for the AveragePredictor.
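The configuration only carries a name with the default 'Average Predictor'. The sketch below mirrors that signature with a dataclass and shows one plausible way a model's name property could read it. Both classes are hypothetical stand-ins; that the real config is a dataclass and that AveragePredictor.name returns the config's name are assumptions, not documented behavior.

```python
from dataclasses import dataclass


@dataclass
class AveragePredictorConfigSketch:
    """Hypothetical stand-in mirroring AveragePredictorConfig's documented signature."""

    name: str = "Average Predictor"


class NamedModelSketch:
    """Hypothetical model whose name property reads from its config (assumed wiring)."""

    def __init__(self, model_params: AveragePredictorConfigSketch) -> None:
        self.model_params = model_params

    @property
    def name(self) -> str:
        return self.model_params.name


print(NamedModelSketch(AveragePredictorConfigSketch()).name)  # Average Predictor
print(NamedModelSketch(AveragePredictorConfigSketch(name="Baseline")).name)  # Baseline
```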