Write plugins

You can write plugins to extend the simba_ml framework, for example by adding models or metrics. A plugin is a module (a Python file) that defines a register function, which the framework calls. Inside that function, register each new model or metric with the corresponding factory.
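The framework's own plugin loader is not described here. As a rough illustration only, a loader that follows this convention could import the plugin module by name with importlib and call its register function. The function name load_plugin and the module name my_plugin are hypothetical; simba_ml's actual loading mechanism may differ.

```python
import importlib
import sys
import types


def load_plugin(module_name: str) -> None:
    """Hypothetical loader: import a plugin module and call its register() function."""
    module = importlib.import_module(module_name)
    register = getattr(module, "register", None)
    if not callable(register):
        raise TypeError(f"Plugin module {module_name!r} does not define a callable register()")
    register()


# Usage: simulate a plugin by placing a module object in sys.modules,
# so importlib finds it without a file on disk.
registered = []
plugin = types.ModuleType("my_plugin")
plugin.register = lambda: registered.append("ZeroPredictor")
sys.modules["my_plugin"] = plugin

load_plugin("my_plugin")
print(registered)  # ['ZeroPredictor']
```

In a real setup, the plugin would be a .py file on the Python import path rather than a module object injected into sys.modules.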

The following example plugin adds a model that always predicts zero:

>>> import dataclasses
>>> import numpy as np
>>> import numpy.typing as npt
>>>
>>> from simba_ml.prediction.time_series.models import model
>>> from simba_ml.prediction.time_series.models import factory
>>>
>>>
>>> @dataclasses.dataclass
... class ZeroPredictorConfig(model.ModelConfig):
...     """Defines the configuration for the ZeroPredictor."""
...     name: str = "Zero Predictor"
...
...
>>> class ZeroPredictor(model.Model):
...     """Defines a model that always predicts zero."""
...
...     def __init__(self, input_length: int, output_length: int, config: ZeroPredictorConfig):
...         """Inits the `ZeroPredictor`.
...
...         Args:
...             input_length: the length of the input data.
...             output_length: the length of the output data.
...             config: the config for the model
...         """
...         super().__init__(input_length, output_length, config)
...
...     def train(self, train: list[npt.NDArray[np.float64]], val: list[npt.NDArray[np.float64]]) -> None:
...         # The model has no parameters to learn, so training is a no-op.
...         pass
...
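...     def predict(self, data: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
...         self.validate_prediction_input(data)
...         # Always predict zero: keep the input's first and last dimensions,
...         # and use output_length for the middle (time) dimension.
...         return np.zeros((data.shape[0], self.output_length, data.shape[2]))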
...
...
>>> def register() -> None:
...     """Registers the ZeroPredictor with the model factory."""
...     factory.register(
...         "ZeroPredictor",
...         ZeroPredictorConfig,
...         ZeroPredictor
...     )
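To make the factory's role concrete, here is a self-contained sketch of a name-based registry that mirrors the factory.register(name, config class, model class) call above. ModelFactory, its create method, and the stand-in classes are hypothetical and do not use simba_ml; the real factory may store and instantiate models differently. The (samples, time steps, features) input layout is an assumption inferred from how ZeroPredictor.predict builds its output.

```python
import dataclasses
from typing import Any

import numpy as np
import numpy.typing as npt


class ModelFactory:
    """Hypothetical registry mapping a model name to its (config class, model class)."""

    def __init__(self) -> None:
        self._models: dict[str, tuple[type, type]] = {}

    def register(self, name: str, config_class: type, model_class: type) -> None:
        # Refuse to overwrite an existing entry so plugins cannot clash silently.
        if name in self._models:
            raise ValueError(f"A model named {name!r} is already registered")
        self._models[name] = (config_class, model_class)

    def create(self, name: str, input_length: int, output_length: int, **config_kwargs: Any) -> Any:
        # Look up the classes, build the config, then instantiate the model.
        config_class, model_class = self._models[name]
        config = config_class(**config_kwargs)
        return model_class(input_length, output_length, config)


# Stand-ins for the plugin classes above, without the simba_ml base classes.
@dataclasses.dataclass
class ZeroPredictorConfig:
    name: str = "Zero Predictor"


class ZeroPredictor:
    def __init__(self, input_length: int, output_length: int, config: ZeroPredictorConfig) -> None:
        self.input_length = input_length
        self.output_length = output_length
        self.config = config

    def predict(self, data: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
        # (samples, input_length, features) -> (samples, output_length, features), all zeros.
        return np.zeros((data.shape[0], self.output_length, data.shape[2]))


factory = ModelFactory()
factory.register("ZeroPredictor", ZeroPredictorConfig, ZeroPredictor)
model = factory.create("ZeroPredictor", input_length=5, output_length=3)
print(model.predict(np.ones((2, 5, 4))).shape)  # (2, 3, 4)
```

Rejecting duplicate names is a design choice of this sketch, so that two plugins cannot silently replace each other's models.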