Trainer
The trainer module defines how models are trained during system operation. It manages the training process, decides when to train, and handles synchronization between training and inference models.
Basic Concepts
Trainers in PAMIQ-Core operate in the training thread, processing data collected during inference. The key responsibilities of a trainer include:
- Deciding when to train: Using data availability conditions to determine if training should occur
- Executing training logic: Implementing the actual learning algorithm
- Synchronizing models: Ensuring updated parameters are available for inference
The training process flows as follows:
- Check if training conditions are met via is_trainable()
- Run setup via setup()
- Execute training via train()
- Synchronize models via sync_models()
- Clean up resources via teardown()
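The sequence above can be pictured with a small self-contained sketch. SketchTrainer is a hypothetical stand-in, not the PAMIQ-Core Trainer class; it only records the order in which the hooks fire, assuming a single pass runs them in the listed order and skips the remaining steps when is_trainable() returns False.

```python
class SketchTrainer:
    """Hypothetical stand-in that records hook order (not the PAMIQ-Core class)."""

    def __init__(self, trainable: bool) -> None:
        self.trainable = trainable
        self.calls: list[str] = []

    def is_trainable(self) -> bool:
        self.calls.append("is_trainable")
        return self.trainable

    def setup(self) -> None:
        self.calls.append("setup")

    def train(self) -> None:
        self.calls.append("train")

    def sync_models(self) -> None:
        self.calls.append("sync_models")

    def teardown(self) -> None:
        self.calls.append("teardown")

    def run(self) -> bool:
        """Run one pass; return True only if training actually happened."""
        if not self.is_trainable():
            return False
        self.setup()
        self.train()
        self.sync_models()
        self.teardown()
        return True


ready = SketchTrainer(trainable=True)
ready.run()
print(ready.calls)

waiting = SketchTrainer(trainable=False)
waiting.run()
print(waiting.calls)
```

In this sketch, a trainer whose condition is not met records only the is_trainable check, which is why the condition check is cheap to call repeatedly.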
Implementing Custom Trainers
To implement a custom trainer, extend the Trainer class and override the train method:
from pamiq_core import Trainer
from typing import override


class MyTrainer(Trainer):
    """Custom trainer that implements a simple learning algorithm."""

    @override
    def train(self) -> None:
        """Implement the training logic.

        This method is called repeatedly by the training thread when
        conditions for training are met.
        """
        # Get the model
        model = self.get_training_model("policy_model")

        # Training logic
        for _ in range(10):  # 10 optimization steps
            # ... implement training algorithm
            pass

        # Model parameters will be automatically synchronized after training
Training Conditions
PAMIQ-Core trainers can be configured with conditions to determine when training should occur:
# Create a trainer that only trains when:
# 1. The "experience" buffer has at least 1000 samples
# 2. At least 100 new samples have been collected since last training
trainer = MyTrainer(
    training_condition_data_user="experience",
    min_buffer_size=1000,
    min_new_data_count=100,
)
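Read together, the two thresholds behave like a conjunction: both must hold before training runs. The helper below is a hypothetical, self-contained sketch of that reading (should_train is not a PAMIQ-Core function), assuming both thresholds are inclusive lower bounds, as the "at least" wording in the comments above suggests.

```python
def should_train(
    buffer_size: int,
    new_data_count: int,
    min_buffer_size: int = 1000,
    min_new_data_count: int = 100,
) -> bool:
    """Return True when both data-availability thresholds are met."""
    has_enough_data = buffer_size >= min_buffer_size
    has_enough_new_data = new_data_count >= min_new_data_count
    return has_enough_data and has_enough_new_data


# A full buffer with too few new samples does not trigger training.
print(should_train(buffer_size=5000, new_data_count=99))
# Meeting both thresholds exactly does.
print(should_train(buffer_size=1000, new_data_count=100))
```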
You can also implement custom training conditions by overriding the is_trainable method:
from datetime import datetime

@override
def is_trainable(self) -> bool:
    """Determine if training should be executed.

    Returns:
        True if training should proceed, False otherwise
    """
    # Check parent condition first
    if not super().is_trainable():
        return False

    # Add custom condition - only train during daytime
    # (hours 8 through 20 inclusive, local time)
    current_hour = datetime.now().hour
    return 8 <= current_hour <= 20
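Because the condition above reads the system clock directly, it can be awkward to test. One option, sketched here under the assumption that the same inclusive hour range is wanted, is to move the comparison into a plain function that takes the time as an argument. The name is_daytime and its parameters are hypothetical and not part of PAMIQ-Core.

```python
from datetime import datetime


def is_daytime(now: datetime, start_hour: int = 8, end_hour: int = 20) -> bool:
    """Return True when now.hour lies in [start_hour, end_hour], inclusive."""
    return start_hour <= now.hour <= end_hour


# The inclusive upper bound means the whole 20:00 hour still counts as daytime.
print(is_daytime(datetime(2024, 1, 1, 20, 59)))
print(is_daytime(datetime(2024, 1, 1, 21, 0)))
```

An overridden is_trainable could then return is_daytime(datetime.now()) after the parent check, while tests pass fixed datetime values.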
Accessing Training Models
To access models for training, override the on_training_models_attached callback:
@override
def on_training_models_attached(self) -> None:
    """Called when training models are attached to the trainer."""
    # Store references to models for convenient access
    self.policy_model = self.get_training_model("policy")
    self.value_model = self.get_training_model("value")
Using Experience Data
To access collected data for training, override the on_data_users_attached callback:
@override
def on_data_users_attached(self) -> None:
    """Called when data users are attached to the trainer."""
    # Store references to data users for convenient access
    self.experience_data = self.get_data_user("experience")

@override
def train(self) -> None:
    """Train models using collected experience data."""
    # Update to get latest data
    self.experience_data.update()

    # Access the data
    data = self.experience_data.get_data()
    states = data["state"]
    actions = data["action"]
    rewards = data["reward"]

    # Use the data to train models
    # ...
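What happens with the retrieved data is left open above. One common pattern is to walk over it in mini-batches. The sketch below is self-contained and makes assumptions that the text does not confirm: that the data is a dictionary of equal-length sequences keyed by "state", "action", and "reward", and that a hypothetical helper named iter_minibatches is acceptable. The actual return type of get_data() depends on the data buffer in use.

```python
from collections.abc import Iterator, Sequence
from typing import Any


def iter_minibatches(
    data: dict[str, Sequence[Any]], batch_size: int
) -> Iterator[dict[str, Sequence[Any]]]:
    """Yield consecutive slices of every column, batch_size rows at a time."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    lengths = {len(column) for column in data.values()}
    if len(lengths) > 1:
        raise ValueError("all columns must have the same length")
    total = lengths.pop() if lengths else 0
    for start in range(0, total, batch_size):
        stop = start + batch_size
        yield {key: column[start:stop] for key, column in data.items()}


# Hypothetical sample data shaped like the dictionary used above
data = {
    "state": [0.1, 0.2, 0.3, 0.4, 0.5],
    "action": [0, 1, 0, 1, 1],
    "reward": [1.0, 0.0, 0.5, 0.0, 1.0],
}
for batch in iter_minibatches(data, batch_size=2):
    print(batch["state"], batch["action"], batch["reward"])
```

The last batch may be shorter than batch_size; whether to keep or drop it is a design choice that depends on the training algorithm.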
Lifecycle Callbacks
PAMIQ-Core trainers have several lifecycle callbacks that you can override:
Setup and Teardown
@override
def setup(self) -> None:
    """Called before training starts.

    Use this to initialize resources needed for training.
    """
    super().setup()  # Always call the parent method

    # Initialize training resources
    self.optimizer = SomeOptimizer(self.policy_model.parameters())
    self.batch_size = 64

@override
def teardown(self) -> None:
    """Called after training finishes.

    Use this to clean up resources after training.
    """
    super().teardown()  # Always call the parent method

    # Clean up resources
    self.optimizer = None
Thread Events
As with other PAMIQ-Core components, trainers can respond to system pause and resume events:
@override
def on_paused(self) -> None:
    """Called when the system is paused."""
    super().on_paused()  # Always call the parent method

    # Pause external connections or resources
    if hasattr(self, "external_service"):
        self.external_service.pause()

@override
def on_resumed(self) -> None:
    """Called when the system is resumed."""
    super().on_resumed()  # Always call the parent method

    # Resume external connections or resources
    if hasattr(self, "external_service"):
        self.external_service.resume()
These event hooks enable proper resource management in response to system state changes.
PyTorch Integration 🔥
For training deep learning models, PAMIQ-Core provides specialized trainer classes in the torch submodule.
For more details on using PyTorch with PAMIQ-Core, see the PyTorch Integration Guide.
API Reference
For more details, check out the API Reference.