Skip to content

Trainer

pamiq_core.trainer.Trainer

Trainer(
    training_condition_data_user: str | None = None,
    min_buffer_size: int = 0,
    min_new_data_count: int = 0,
)

Bases: ABC, PersistentStateMixin, ThreadEventMixin

Abstract base trainer class.

The run method is called repeatedly in the training thread.

Override the following methods
  • on_training_models_attached: Callback method for when training models are attached to the trainer.
  • on_data_users_attached: Callback method for when data users are attached to the trainer.
  • is_trainable: Return whether the training can be executed.
  • setup: Prepare for training before it starts.
  • train: The training process.
  • teardown: Clean up after training.

Models and data buffers become available after the thread has started.

Initialize a trainer.

PARAMETER DESCRIPTION
training_condition_data_user

Name of the data user to check for trainability. If None, the trainer is always trainable.

TYPE: str | None DEFAULT: None

min_buffer_size

Minimum total data points required in the buffer.

TYPE: int DEFAULT: 0

min_new_data_count

Minimum number of new data points required since last training.

TYPE: int DEFAULT: 0

Source code in src/pamiq_core/trainer/base.py
def __init__(
    self,
    training_condition_data_user: str | None = None,
    min_buffer_size: int = 0,
    min_new_data_count: int = 0,
) -> None:
    """Create a trainer with optional data-driven training conditions.

    Args:
        training_condition_data_user: Name of the data user whose buffer gates
            training. When None, :meth:`is_trainable` always returns True and
            the two thresholds below are ignored.
        min_buffer_size: Minimum number of data points the gating data user
            must hold before training may run.
        min_new_data_count: Minimum number of data points that must have been
            added since the previous successful trainability check.
    """
    # Cooperative initialisation of the base classes / mixins.
    super().__init__()

    # Trainability gate configuration; only consulted by is_trainable.
    self._training_condition_data_user = training_condition_data_user
    self._min_buffer_size = min_buffer_size
    self._min_new_data_count = min_new_data_count

    # Cut-off timestamp passed to count_data_added_since. Starting at -inf
    # presumably makes all buffered data count as "new" before the first
    # training run; it is advanced whenever is_trainable succeeds.
    self._previous_training_time = float("-inf")

    # Names handed out by get_training_model; sync_models syncs exactly these.
    self._retrieved_model_names: set[str] = set()

attach_training_models

attach_training_models(training_models: TrainingModelsDict) -> None

Attaches TrainingModelsDict to this trainer.

Source code in src/pamiq_core/trainer/base.py
def attach_training_models(self, training_models: TrainingModelsDict) -> None:
    """Store ``training_models`` on the trainer and notify the subclass.

    Once the mapping is stored, the :meth:`on_training_models_attached`
    hook runs so subclasses can fetch the models they intend to train.
    """
    self._training_models = training_models
    # Subclass hook; the base implementation does nothing.
    self.on_training_models_attached()

on_training_models_attached

on_training_models_attached() -> None

Callback method for when training models are attached to the trainer.

Use `get_training_model` to retrieve the model that will be trained.

Source code in src/pamiq_core/trainer/base.py
def on_training_models_attached(self) -> None:
    """Hook run by :meth:`attach_training_models` after models are stored.

    The default implementation does nothing. Override it and call
    :meth:`get_training_model` to fetch the models this trainer trains;
    every model fetched that way is later synchronized by
    :meth:`sync_models`.
    """

attach_data_users

attach_data_users(data_users: DataUsersDict) -> None

Attaches DataUsersDict to this trainer.

Source code in src/pamiq_core/trainer/base.py
def attach_data_users(self, data_users: DataUsersDict) -> None:
    """Store ``data_users`` on the trainer and notify the subclass.

    Once the mapping is stored, the :meth:`on_data_users_attached` hook
    runs so subclasses can look up the data users they need.
    """
    self._data_users = data_users
    # Subclass hook; the base implementation does nothing.
    self.on_data_users_attached()

on_data_users_attached

on_data_users_attached() -> None

Callback method when data users are attached to the trainer.

Use `get_data_user` to obtain the data user for a dataset.

Source code in src/pamiq_core/trainer/base.py
def on_data_users_attached(self) -> None:
    """Hook run by :meth:`attach_data_users` after data users are stored.

    The default implementation does nothing. Override it and call
    :meth:`get_data_user` to obtain the data user backing a dataset.
    """

get_training_model

get_training_model(name: str) -> TrainingModel[Any]

Retrieves the training model.

If the specified model includes an inference model, it is automatically synchronized after training.

Source code in src/pamiq_core/trainer/base.py
def get_training_model(self, name: str) -> TrainingModel[Any]:
    """Look up a training model by name and mark it for synchronization.

    Every name successfully fetched here is remembered, and
    :meth:`sync_models` later calls ``sync()`` on each such model. As a
    result, if the model includes an inference model, that inference model
    is synchronized automatically after training.

    Args:
        name: Key of the model in the attached TrainingModelsDict.

    Returns:
        The training model registered under ``name``.
    """
    # Look up first so that a failed lookup does not record the name.
    training_model = self._training_models[name]
    self._retrieved_model_names.add(name)
    return training_model

get_data_user

get_data_user(name: str) -> DataUser[Any]

Retrieves the data user.

Source code in src/pamiq_core/trainer/base.py
def get_data_user(self, name: str) -> DataUser[Any]:
    """Return the data user registered under ``name``.

    Only usable after :meth:`attach_data_users` has stored the mapping.
    """
    data_users = self._data_users
    return data_users[name]

is_trainable

is_trainable() -> bool

Determines if the training can be executed.

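Checks whether training can proceed based on data availability when a training condition data user is specified. When the check passes, the current time is recorded and used as the reference point for counting new data at the next check.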

RETURNS DESCRIPTION
bool

True if training can be executed, False otherwise.

Source code in src/pamiq_core/trainer/base.py
def is_trainable(self) -> bool:
    """Decide whether a training cycle may run now.

    With no training condition data user configured, training is always
    allowed. Otherwise that data user is refreshed via ``update()`` and
    training is allowed only if it holds at least ``min_buffer_size`` data
    points and at least ``min_new_data_count`` of them were added since the
    previous successful check.

    Side effect: when the result is True, ``time.time()`` is stored as the
    new cut-off for the next check. Overrides that call
    ``super().is_trainable()`` and then add their own conditions should be
    aware that a True result here has already advanced that cut-off.

    Returns:
        True if training can be executed, False otherwise.
    """
    condition_user_name = self._training_condition_data_user
    if condition_user_name is None:
        # No gating data user configured: training is unconditionally allowed.
        return True

    condition_user = self.get_data_user(condition_user_name)
    condition_user.update()

    # Checked in this order so the new-data count is only queried when the
    # buffer is large enough.
    if len(condition_user) < self._min_buffer_size:
        return False
    new_data_count = condition_user.count_data_added_since(
        self._previous_training_time
    )
    if new_data_count < self._min_new_data_count:
        return False

    self._previous_training_time = time.time()
    return True

setup

setup() -> None

Setup procedure before training starts.

Source code in src/pamiq_core/trainer/base.py
def setup(self) -> None:
    """Setup procedure before training starts."""
    pass

train abstractmethod

train() -> None

Train models.

Please build the models, optimizers, dataset, and other components in this method. This method is called repeatedly.

After this method returns, `sync_models` is called.

Source code in src/pamiq_core/trainer/base.py
@abstractmethod
def train(self) -> None:
    """Train the models for one cycle.

    Subclasses must implement this. Build the models, optimizers, datasets
    and any other components needed for training here. :meth:`run` calls
    this method once per invocation, after :meth:`setup`; once it returns,
    :meth:`sync_models` is called.
    """
    ...

sync_models

sync_models() -> None

Synchronizes params of trained models to inference models.

Source code in src/pamiq_core/trainer/base.py
def sync_models(self) -> None:
    """Push trained parameters to the inference models.

    Only models previously fetched through :meth:`get_training_model` are
    synchronized; ``sync()`` is called once on each of them.
    """
    models = self._training_models
    for model_name in self._retrieved_model_names:
        models[model_name].sync()

teardown

teardown() -> None

Teardown procedure after training.

Source code in src/pamiq_core/trainer/base.py
def teardown(self) -> None:
    """Teardown procedure after training."""
    pass

run

run() -> bool

Runs the training process if the trainer is trainable.

RETURNS DESCRIPTION
bool

True if training was executed, False if skipped due to conditions not met.

TYPE: bool

Source code in src/pamiq_core/trainer/base.py
def run(self) -> bool:
    """Runs the training process if the trainer is trainable.

    One cycle is ``setup -> train -> sync_models -> teardown``. Once
    :meth:`setup` has succeeded, :meth:`teardown` is always invoked, even
    if :meth:`train` or :meth:`sync_models` raises, so that resources
    acquired in ``setup`` are released. The exception still propagates to
    the caller. Models are synchronized only when ``train`` completes
    without raising.

    Returns:
        bool: True if training was executed, False if skipped due to conditions not met.
    """
    if not self.is_trainable():
        return False

    self.setup()
    try:
        self.train()
        # Only publish parameters from a training step that completed.
        self.sync_models()
    finally:
        # Previously skipped on failure, leaking whatever setup() acquired.
        self.teardown()
    return True

save_state

save_state(path: Path) -> None

Save the trainer state to the specified path.

PARAMETER DESCRIPTION
path

Directory path where to save the trainer state.

TYPE: Path

Source code in src/pamiq_core/trainer/base.py
@override
def save_state(self, path: Path) -> None:
    """Persist the trainer state into a newly created directory.

    The directory ``path`` is created with ``Path.mkdir()`` using its default
    arguments, so it must not already exist and its parent must exist.
    The previous training timestamp is written as text to the file
    ``previous_training_time`` inside it.

    Args:
        path: Directory path where to save the trainer state.
    """
    path.mkdir()
    timestamp_file = path / "previous_training_time"
    # str() of a float round-trips through float(), including "-inf".
    timestamp_file.write_text(str(self._previous_training_time), encoding="utf-8")

load_state

load_state(path: Path) -> None

Load the trainer state from the specified path.

PARAMETER DESCRIPTION
path

Directory path from where to load the trainer state.

TYPE: Path

Source code in src/pamiq_core/trainer/base.py
@override
def load_state(self, path: Path) -> None:
    """Restore the trainer state written by :meth:`save_state`.

    Reads the ``previous_training_time`` file inside ``path`` and parses
    it back into a float.

    Args:
        path: Directory path from where to load the trainer state.
    """
    timestamp_file = path / "previous_training_time"
    timestamp_text = timestamp_file.read_text(encoding="utf-8")
    self._previous_training_time = float(timestamp_text)