Torch🔥
pamiq_core.torch.TorchTrainingModel
TorchTrainingModel(
model: T,
has_inference_model: bool = True,
inference_thread_only: bool = False,
device: device | str | None = None,
dtype: dtype | None = None,
inference_procedure: InferenceProcedureCallable[T] | str = default_infer_procedure,
pretrained_parameter_file: str | Path | None = None,
compile: bool = False,
)
Bases: TrainingModel[TorchInferenceModel[T]]
PyTorch model wrapper for parallel training and inference.
This class enables efficient multi-threaded operation where training and inference can run in parallel on separate model instances. It manages model synchronization between threads and provides various initialization options.
Type Parameters
T: The type of the PyTorch model (must be nn.Module subclass).
Initialize the TorchTrainingModel.
PARAMETER | DESCRIPTION |
---|---|
model | The PyTorch model to wrap for training. |
has_inference_model | Whether to create a separate inference model. If False, inference is not supported. |
inference_thread_only | If True, the same model instance is shared between training and inference (no copying). Use this when the model is only used for inference. |
device | Device to place the model on. If None, the model stays on its current device. |
dtype | Data type for the model parameters. If specified, the model is converted to this dtype. |
inference_procedure | The procedure to use for inference. Can be a callable following the InferenceProcedureCallable protocol, a string naming a method on the model class, or the default_infer_procedure function (the default). |
pretrained_parameter_file | Path to a pre-trained model parameter file. If provided, parameters are loaded from this file after initialization. |
compile | Whether to compile the model with torch.compile() for potentially better performance. If has_inference_model is True, both models are compiled. |
RAISES | DESCRIPTION |
---|---|
AttributeError | If inference_procedure is a string that does not name an attribute of the model class. |
ValueError | If inference_procedure is a string that names an attribute of the model class that is not a callable method. |
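The string form of inference_procedure is resolved against the model class, which is where the two errors above come from. The following is a minimal sketch of that lookup, assuming the resolved method is later called with the model instance as its first argument; resolve_inference_procedure and TinyModel are hypothetical names, not part of pamiq_core.

```python
from __future__ import annotations

from typing import Any, Callable


class TinyModel:
    """Hypothetical stand-in for an nn.Module subclass."""

    scale = 2  # A non-callable class attribute.

    def __call__(self, x: int) -> int:
        return x + 1

    def infer_doubled(self, x: int) -> int:
        # A custom inference method that can be referenced by name.
        return self(x) * self.scale


def resolve_inference_procedure(
    model_cls: type, procedure: Callable[..., Any] | str
) -> Callable[..., Any]:
    """Hypothetical resolver that mirrors the documented errors."""
    if not isinstance(procedure, str):
        return procedure  # Already a callable: use it unchanged.
    if not hasattr(model_cls, procedure):
        raise AttributeError(f"{model_cls.__name__} has no attribute {procedure!r}")
    method = getattr(model_cls, procedure)
    if not callable(method):
        raise ValueError(f"{model_cls.__name__}.{procedure} is not callable")
    # Unbound method: it is called as method(model, *args, **kwds).
    return method


proc = resolve_inference_procedure(TinyModel, "infer_doubled")
print(proc(TinyModel(), 3))  # 8
```

Returning the unbound class attribute keeps the resolved procedure consistent with the callable form, where the model is always passed explicitly as the first argument.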
Source code in src/pamiq_core/torch/model.py
sync_impl
Synchronize training model parameters to the inference model.
This method implements an efficient parameter synchronization strategy by swapping model references and copying state dictionaries. It preserves gradients on the training model during the sync operation.
PARAMETER | DESCRIPTION |
---|---|
inference_model | The inference model to synchronize parameters to. |
Note
The models are put in eval mode during sync and returned to train mode afterwards to ensure proper behavior of layers like BatchNorm and Dropout.
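The swap-then-copy strategy described above can be illustrated with a dependency-free sketch. ToyModel, ToyInferenceWrapper, and ToyTrainingWrapper are hypothetical stand-ins (scalar parameters in a dict instead of tensors), and the sketch omits the locking a real multithreaded sync would need around the reference swap; it is not pamiq_core's implementation.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class ToyModel:
    """Hypothetical nn.Module stand-in: named scalar params plus their grads."""

    params: dict[str, float]
    grads: dict[str, float] = field(default_factory=dict)
    training: bool = True

    def train(self) -> None:
        self.training = True

    def eval(self) -> None:
        self.training = False

    def state_dict(self) -> dict[str, float]:
        return dict(self.params)

    def load_state_dict(self, state: dict[str, float]) -> None:
        self.params = dict(state)


@dataclass
class ToyInferenceWrapper:
    """Hypothetical holder of the model used by the inference thread."""

    raw_model: ToyModel


class ToyTrainingWrapper:
    def __init__(self, model: ToyModel) -> None:
        self.model = model

    def sync_impl(self, inference: ToyInferenceWrapper) -> None:
        self.model.eval()
        # Detach the gradients so they survive the swap.
        grads = self.model.grads
        self.model.grads = {}
        # Swap references: the freshly trained model now serves inference.
        self.model, inference.raw_model = inference.raw_model, self.model
        # Copy the trained parameters into the new training model.
        self.model.load_state_dict(inference.raw_model.state_dict())
        # Reattach the gradients and resume training mode.
        self.model.grads = grads
        self.model.train()


trainer = ToyTrainingWrapper(ToyModel({"w": 5.0}, grads={"w": 0.3}))
inference = ToyInferenceWrapper(ToyModel({"w": 0.0}, training=False))
trainer.sync_impl(inference)
print(inference.raw_model.params, trainer.model.params, trainer.model.grads)
```

Swapping references hands the trained weights to the inference side at the cost of a pointer exchange; the state-dict copy then only touches the model that the training thread owns.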
Source code in src/pamiq_core/torch/model.py
forward
Forward pass through the training model.
PARAMETER | DESCRIPTION |
---|---|
*args | Positional arguments to pass to the model. |
**kwds | Keyword arguments to pass to the model. |
RETURNS | DESCRIPTION |
---|---|
Any | The output from the model's forward pass. |
Source code in src/pamiq_core/torch/model.py
save_state
Save the model parameters.
PARAMETER | DESCRIPTION |
---|---|
path | Base path for saving the model state. The actual file is saved as "{path}.pt". |
Source code in src/pamiq_core/torch/model.py
load_state
Load model parameters.
PARAMETER | DESCRIPTION |
---|---|
path | Base path for loading the model state. The file actually loaded is "{path}.pt". |
Source code in src/pamiq_core/torch/model.py
pamiq_core.torch.TorchInferenceModel
Bases: InferenceModel
Thread-safe wrapper for PyTorch models used in inference.
This class provides a thread-safe interface for performing inference with PyTorch models in multi-threaded environments. It uses a lock to ensure that model updates and inference operations don't interfere with each other.
Type Parameters
T: The type of the PyTorch model (must be nn.Module subclass).
Initialize the TorchInferenceModel.
PARAMETER | DESCRIPTION |
---|---|
model | A PyTorch model to wrap for thread-safe inference. |
inference_procedure | A callable that defines how to perform inference with the model. It should take the model as its first argument, followed by any additional arguments. |
Source code in src/pamiq_core/torch/model.py
infer
Perform thread-safe inference with gradient computation disabled.
This method executes the inference procedure with the model while ensuring thread safety through locking and disabling gradient computation for efficiency.
PARAMETER | DESCRIPTION |
---|---|
*args | Positional arguments to pass to the inference procedure. |
**kwds | Keyword arguments to pass to the inference procedure. |
RETURNS | DESCRIPTION |
---|---|
Any | The output from the inference procedure. |
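The locking scheme described for TorchInferenceModel can be sketched without PyTorch. LockedInferenceModel and swap_model are hypothetical names; the sketch shows only the lock discipline (inference and model replacement take the same lock) and leaves out the gradient disabling that the real infer method performs.

```python
from __future__ import annotations

import threading
from typing import Any, Callable


class LockedInferenceModel:
    """Hypothetical sketch of a lock-guarded inference wrapper."""

    def __init__(self, model: Any, inference_procedure: Callable[..., Any]) -> None:
        self._model = model
        self._procedure = inference_procedure
        self._lock = threading.Lock()

    def swap_model(self, new_model: Any) -> Any:
        # Updates take the same lock, so no inference sees a half-swapped model.
        with self._lock:
            old, self._model = self._model, new_model
            return old

    def infer(self, *args: Any, **kwds: Any) -> Any:
        # The real TorchInferenceModel also disables gradient computation here.
        with self._lock:
            return self._procedure(self._model, *args, **kwds)


wrapper = LockedInferenceModel(lambda x: x + 1, lambda model, x: model(x))
print(wrapper.infer(1))  # 2
```

A single coarse lock is the simplest correct choice: an inference call always runs against one complete model, at the cost of briefly blocking inference while a sync swaps models.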
Source code in src/pamiq_core/torch/model.py
pamiq_core.torch.TorchTrainer
TorchTrainer(
training_condition_data_user: str | None = None,
min_buffer_size: int = 0,
min_new_data_count: int = 0,
)
Bases: Trainer
Base class for PyTorch model training in pamiq-core.
This trainer integrates PyTorch models with the pamiq-core framework, providing functionality for optimizer configuration, state persistence, and model type validation. It automatically handles the setup and teardown of optimizers and learning rate schedulers during the training process.
Subclasses should implement the create_optimizers and train methods to define the specific training behavior.
Initialize the PyTorch trainer.
Sets up empty containers for optimizers, schedulers, and their respective states. Actual optimizer and scheduler instances will be created during the setup phase.
PARAMETER | DESCRIPTION |
---|---|
training_condition_data_user | Name of the data user to check for trainability. If None, the trainer is always trainable. |
min_buffer_size | Minimum total number of data points required in the buffer. |
min_new_data_count | Minimum number of new data points required since the last training. |
Source code in src/pamiq_core/torch/trainer.py
get_torch_training_model
Get a TorchTrainingModel with type checking. Retrieves a PyTorch training model by name and validates the type of its internal model.
PARAMETER | DESCRIPTION |
---|---|
name | Name of the model to retrieve. |
module_cls | Expected class of the internal module. |
RETURNS | DESCRIPTION |
---|---|
TorchTrainingModel[T] | An instance of TorchTrainingModel whose internal model is of the specified type. |
RAISES | DESCRIPTION |
---|---|
TypeError | If the model is not a TorchTrainingModel or its internal model is not an instance of the specified module class. |
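The two checks behind get_torch_training_model can be sketched with plain classes. WrappedModel and get_wrapped_model are hypothetical stand-ins for TorchTrainingModel and the real method; the point is that a successful call both validates the runtime type and narrows the static return type to WrappedModel[T].

```python
from __future__ import annotations

from typing import Any, Generic, TypeVar

T = TypeVar("T")


class WrappedModel(Generic[T]):
    """Hypothetical stand-in for TorchTrainingModel[T]."""

    def __init__(self, model: T) -> None:
        self.model = model


def get_wrapped_model(
    models: dict[str, Any], name: str, module_cls: type[T]
) -> WrappedModel[T]:
    wrapped = models[name]
    # Check 1: the registered object must be the expected wrapper type.
    if not isinstance(wrapped, WrappedModel):
        raise TypeError(f"Model {name!r} is not a WrappedModel")
    # Check 2: the wrapped module must be an instance of module_cls.
    if not isinstance(wrapped.model, module_cls):
        raise TypeError(
            f"Internal model of {name!r} is not an instance of {module_cls.__name__}"
        )
    return wrapped
```

Failing early with TypeError at lookup time is usually easier to debug than an attribute error deep inside a training step.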
Source code in src/pamiq_core/torch/trainer.py
create_optimizers abstractmethod
Create and return optimizers and, optionally, learning rate schedulers for training. Implementations should create an optimizer for each model being trained and may also create learning rate schedulers.
Returns: either
- a dictionary mapping names to optimizers, or
- a tuple containing (optimizers dictionary, schedulers dictionary).
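Because create_optimizers may return either shape, the caller has to normalize the result. The helper below is a hypothetical sketch of that step; in a real subclass the values would be PyTorch optimizer and scheduler objects, while plain strings stand in for them here.

```python
from __future__ import annotations

from typing import Any


def split_optimizers(
    result: dict[str, Any] | tuple[dict[str, Any], dict[str, Any]],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Hypothetical helper: normalize both documented return shapes."""
    if isinstance(result, tuple):
        optimizers, schedulers = result
        return optimizers, schedulers
    return result, {}  # Optimizers only: no schedulers.


print(split_optimizers({"encoder": "adam"}))  # ({'encoder': 'adam'}, {})
```

Normalizing to a pair of dictionaries lets the rest of the trainer handle schedulers uniformly, treating "no schedulers" as an empty dictionary.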
Source code in src/pamiq_core/torch/trainer.py
setup
Set up the training environment.
Initializes optimizers and schedulers by calling the create_optimizers method and restores their states if they were previously saved.
Source code in src/pamiq_core/torch/trainer.py
teardown
Clean up after training.
Keeps the current states of the optimizers and schedulers before cleanup.
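The setup and teardown lifecycle can be sketched as follows: optimizers are rebuilt on every setup, and their states are carried across sessions. ToyOptimizer, ToyTrainer, and MyTrainer are hypothetical; schedulers and the tuple return form of create_optimizers are omitted to keep the sketch short.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any


class ToyOptimizer:
    """Hypothetical stand-in for an optimizer exposing a state_dict."""

    def __init__(self) -> None:
        self.step_count = 0

    def step(self) -> None:
        self.step_count += 1

    def state_dict(self) -> dict[str, Any]:
        return {"step_count": self.step_count}

    def load_state_dict(self, state: dict[str, Any]) -> None:
        self.step_count = state["step_count"]


class ToyTrainer(ABC):
    def __init__(self) -> None:
        self.optimizers: dict[str, ToyOptimizer] = {}
        self.optimizer_states: dict[str, dict[str, Any]] = {}

    @abstractmethod
    def create_optimizers(self) -> dict[str, ToyOptimizer]: ...

    @abstractmethod
    def train(self) -> None: ...

    def setup(self) -> None:
        # Build fresh optimizers, then restore any retained state.
        self.optimizers = self.create_optimizers()
        for name, state in self.optimizer_states.items():
            self.optimizers[name].load_state_dict(state)

    def teardown(self) -> None:
        # Keep the states so the next setup() (or a save) can use them.
        self.optimizer_states = {
            name: opt.state_dict() for name, opt in self.optimizers.items()
        }
        self.optimizers.clear()


class MyTrainer(ToyTrainer):
    def create_optimizers(self) -> dict[str, ToyOptimizer]:
        return {"main": ToyOptimizer()}

    def train(self) -> None:
        self.optimizers["main"].step()
```

Recreating optimizers on each setup keeps them bound to the current model objects, which matters when sync swaps model references; restoring state preserves optimizer momentum and step counts across sessions.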
save_state
Save trainer state to disk.
Persists the states of all optimizers and schedulers to the specified directory path.
PARAMETER | DESCRIPTION |
---|---|
path | Directory path where the state should be saved. |
Source code in src/pamiq_core/torch/trainer.py
load_state
Load trainer state from disk.
Loads the previously saved states of optimizers and schedulers from the specified directory path.
PARAMETER | DESCRIPTION |
---|---|
path | Directory path from which the state should be loaded. |
RAISES | DESCRIPTION |
---|---|
ValueError | If the path does not exist or is not a directory. |
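A directory-based save and load with the documented ValueError check might look like the sketch below. The one-file-per-name layout and the function names are hypothetical, and json stands in for torch.save and torch.load so the example stays dependency-free.

```python
from __future__ import annotations

import json
from pathlib import Path
from typing import Any


def save_states(states: dict[str, dict[str, Any]], path: Path) -> None:
    # Creating the directory here is an assumption of this sketch.
    path.mkdir(parents=True, exist_ok=True)
    for name, state in states.items():
        # One file per optimizer or scheduler; the naming is hypothetical.
        (path / f"{name}.json").write_text(json.dumps(state))


def load_states(path: Path) -> dict[str, dict[str, Any]]:
    if not path.is_dir():
        # is_dir() is False both for missing paths and for regular files.
        raise ValueError(f"{path} does not exist or is not a directory")
    return {f.stem: json.loads(f.read_text()) for f in sorted(path.glob("*.json"))}
```

Checking is_dir() once covers both failure cases listed in the RAISES table.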
Source code in src/pamiq_core/torch/trainer.py
pamiq_core.torch.get_device
Retrieves the device on which the module runs.
PARAMETER | DESCRIPTION |
---|---|
module | The module whose device should be determined. |
default_device | The device to return if no device can be found for the module. |
Returns: The device the module uses, or default_device if none is found.
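One plausible way to implement this lookup is sketched below. Which tensors are consulted, and in what order (parameters first, then buffers), is an assumption, and get_device_sketch is a hypothetical name. Because it only relies on parameters() and buffers() yielding objects with a .device attribute, the same function also accepts a real nn.Module.

```python
from __future__ import annotations

from typing import Any


def get_device_sketch(module: Any, default_device: Any = None) -> Any:
    """Return the device of the first parameter, else the first buffer, else default."""
    for param in module.parameters():
        return param.device
    for buf in module.buffers():
        return buf.device
    return default_device
```

Returning after the first tensor assumes the whole module lives on a single device, which is the usual case for the models wrapped here.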
Source code in src/pamiq_core/torch/model.py
pamiq_core.torch.default_infer_procedure
Default inference procedure with device placement.
This function automatically moves tensor arguments to the same device as the model before performing inference. Non-tensor arguments are passed through unchanged.
PARAMETER | DESCRIPTION |
---|---|
model | The model to run inference with. |
*args | Positional arguments to pass to the model. Tensors are moved to the model's device. |
**kwds | Keyword arguments to pass to the model. Tensor values are moved to the model's device. |
RETURNS | DESCRIPTION |
---|---|
Any | The output from the model's forward pass. |
Note
When providing a custom inference procedure in place of this function, make sure input tensors are moved to the model's device to avoid device-mismatch errors.
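A custom procedure that keeps the device-placement behavior can be built like the sketch below. make_infer_procedure is a hypothetical factory that takes injected get_device and is_tensor callables so it runs without PyTorch; with PyTorch you might pass lambda m: next(m.parameters()).device and lambda x: isinstance(x, torch.Tensor).

```python
from __future__ import annotations

from typing import Any, Callable


def make_infer_procedure(
    get_device: Callable[[Any], Any], is_tensor: Callable[[Any], bool]
) -> Callable[..., Any]:
    """Hypothetical factory for a model-first inference procedure."""

    def procedure(model: Any, *args: Any, **kwds: Any) -> Any:
        device = get_device(model)
        # Move tensor arguments to the model's device; pass others through.
        args = tuple(a.to(device) if is_tensor(a) else a for a in args)
        kwds = {k: v.to(device) if is_tensor(v) else v for k, v in kwds.items()}
        return model(*args, **kwds)

    return procedure
```

The returned function follows the model-first calling convention described for inference_procedure, so it could be passed wherever a callable procedure is accepted.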