Write plugins
You can create plugins to adapt the simba_ml framework, for example by adding models or metrics. A plugin is a module (a Python file) that defines a `register` function, which the framework calls. To add a new model or metric, register it in the corresponding factory from within this function.
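The register-in-a-factory idea can be pictured with a small, self-contained sketch. It assumes the framework imports each plugin module by name and then calls its `register()` function. `MODEL_REGISTRY`, `register_model`, `load_plugins`, `DemoConfig` and `DemoModel` are hypothetical names chosen for illustration; they are not simba_ml's actual API or loading mechanism.

```python
import importlib
import sys
import types
from typing import Dict, Iterable, Tuple

# Hypothetical registry mapping a model name to its (config class, model class)
# pair. It only illustrates the factory idea; it is not simba_ml's implementation.
MODEL_REGISTRY: Dict[str, Tuple[type, type]] = {}


def register_model(name: str, config_class: type, model_class: type) -> None:
    """Hypothetical factory function: records a model under a unique name."""
    if name in MODEL_REGISTRY:
        raise ValueError(f"a model named {name!r} is already registered")
    MODEL_REGISTRY[name] = (config_class, model_class)


def load_plugins(module_names: Iterable[str]) -> None:
    """Imports each plugin module by name and calls its register() function."""
    for module_name in module_names:
        module = importlib.import_module(module_name)
        module.register()


# --- Usage: simulate a plugin module in memory instead of a .py file. ---
class DemoConfig:
    """Placeholder configuration class (hypothetical)."""


class DemoModel:
    """Placeholder model class (hypothetical)."""


demo_plugin = types.ModuleType("demo_plugin")
demo_plugin.register = lambda: register_model("DemoModel", DemoConfig, DemoModel)
sys.modules["demo_plugin"] = demo_plugin  # makes the module importable by name

load_plugins(["demo_plugin"])
print(sorted(MODEL_REGISTRY))  # prints ['DemoModel']
```

Rejecting a duplicate name is a design choice of this sketch only; it does not describe how simba_ml's factory behaves when a name is registered twice.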
The following example plugin adds a model that always predicts zero:
>>> import dataclasses
>>> import numpy as np
>>> import numpy.typing as npt
>>>
>>> from simba_ml.prediction.time_series.models import model
>>> from simba_ml.prediction.time_series.models import factory
>>>
>>>
>>> @dataclasses.dataclass
... class ZeroPredictorConfig(model.ModelConfig):
...     """Defines the configuration for the ZeroPredictor."""
...     name: str = "Zero Predictor"
...
...
>>> class ZeroPredictor(model.Model):
...     """Defines a model that always predicts zero."""
...
...     def __init__(self, input_length: int, output_length: int, config: ZeroPredictorConfig):
...         """Inits the `ZeroPredictor`.
...
...         Args:
...             input_length: the length of the input data.
...             output_length: the length of the output data.
...             config: the config for the model.
...         """
...         super().__init__(input_length, output_length, config)
...
...     def train(self, train: list[npt.NDArray[np.float64]], val: list[npt.NDArray[np.float64]]) -> None:
...         # Nothing to learn: the prediction is always zero.
...         pass
...
...     def predict(self, data: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
...         self.validate_prediction_input(data)
...         # Return zeros with shape (samples, output_length, features).
...         return np.full((data.shape[0], self.output_length, data.shape[2]), 0.0)
...
...
>>> def register() -> None:
...     # Called by the framework when the plugin is loaded.
...     factory.register(
...         "ZeroPredictor",
...         ZeroPredictorConfig,
...         ZeroPredictor
...     )
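Because `predict` does not depend on any trained state, its output contract can be checked without simba_ml installed. The sketch below copies the `np.full` call into a standalone function; the name `zero_forecast` is hypothetical, and the 3-D input shape (samples, input_length, features) is an assumption inferred from `data.shape[2]`, not from the framework's `validate_prediction_input`.

```python
import numpy as np


def zero_forecast(data: np.ndarray, output_length: int) -> np.ndarray:
    """Mirrors ZeroPredictor.predict: zeros of shape (samples, output_length, features)."""
    # Assumption: input is 3-D (samples, input_length, features), as implied by data.shape[2].
    if data.ndim != 3:
        raise ValueError("expected input of shape (samples, input_length, features)")
    return np.full((data.shape[0], output_length, data.shape[2]), 0.0)


batch = np.ones((4, 10, 2))  # 4 samples, input_length 10, 2 features
forecast = zero_forecast(batch, output_length=5)
print(forecast.shape)  # prints (4, 5, 2)
```

The number of samples and features is carried over from the input, while the time dimension is replaced by `output_length`.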