Torch🔥
pamiq_core.torch.TorchAgent
Bases: Agent[O, A]
Agent class specialized for PyTorch models.
This class extends the base Agent class to provide type-safe access to PyTorch inference models.
CLASS TYPE PARAMETER | DESCRIPTION |
---|---|
O | The observation type. |
A | The action type. |
Source code in src/pamiq_core/interaction/agent.py
get_torch_inference_model
Retrieve a TorchInferenceModel with type checking.
This method retrieves an inference model by name and verifies that it is a TorchInferenceModel instance containing the expected PyTorch module type.
PARAMETER | DESCRIPTION |
---|---|
name | The name of the inference model to retrieve. |
module_cls | The expected PyTorch module class. |
RETURNS | DESCRIPTION |
---|---|
TorchInferenceModel[T] | A TorchInferenceModel instance containing a model of the specified type. |
RAISES | DESCRIPTION |
---|---|
ValueError | If the retrieved model is not a TorchInferenceModel instance. |
TypeError | If the internal model is not an instance of module_cls. |
KeyError | If no model with the specified name exists. |
Source code in src/pamiq_core/torch/agent.py
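The three documented failure modes (missing name, wrong wrapper type, wrong module class) can be illustrated without PyTorch. This is a minimal, hypothetical stand-in: `Wrapper`, `raw_model`, `get_checked_model`, and the dict registry are illustrative names, not pamiq-core's API, and only the order of the checks is taken from the table above.

```python
from __future__ import annotations

from typing import Any, TypeVar

T = TypeVar("T")


class Wrapper:
    """Hypothetical stand-in for TorchInferenceModel: it only holds a raw model."""

    def __init__(self, raw_model: Any) -> None:
        self.raw_model = raw_model


def get_checked_model(registry: dict[str, Any], name: str, module_cls: type[T]) -> Wrapper:
    """Hypothetical helper mirroring the documented checks of get_torch_inference_model."""
    # An unknown name raises KeyError, exactly as a plain dict lookup does.
    model = registry[name]
    # The registered object must be the expected wrapper type.
    if not isinstance(model, Wrapper):
        raise ValueError(f"{name!r} is not a Wrapper instance")
    # The wrapped model must be an instance of the requested class.
    if not isinstance(model.raw_model, module_cls):
        raise TypeError(f"{name!r} does not wrap a {module_cls.__name__}")
    return model
```

Checking the wrapper before the inner model keeps the two errors distinguishable: a ValueError means the wrong kind of model was registered, while a TypeError means the right wrapper holds the wrong network.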
pamiq_core.torch.TorchTrainingModel
TorchTrainingModel(
model: T,
has_inference_model: bool = True,
inference_thread_only: bool = False,
device: device | str | None = None,
dtype: dtype | None = None,
inference_procedure: InferenceProcedureCallable[T] | str = default_infer_procedure,
pretrained_parameter_file: str | Path | None = None,
compile: bool = False,
)
Bases: TrainingModel[TorchInferenceModel[T]]
PyTorch model wrapper for parallel training and inference.
This class enables efficient multi-threaded operation where training and inference can run in parallel on separate model instances. It manages model synchronization between threads and provides various initialization options.
CLASS TYPE PARAMETER | DESCRIPTION |
---|---|
T | The type of the PyTorch model (must be an nn.Module subclass). BOUND: nn.Module |
Initialize the TorchTrainingModel.
PARAMETER | DESCRIPTION |
---|---|
model | The PyTorch model to wrap for training. |
has_inference_model | Whether to create a separate inference model. If False, inference is not supported. |
inference_thread_only | If True, the same model instance is shared between training and inference (no copying). Use this when the model is only used for inference. |
device | Device to place the model on. If None, the model stays on its current device. |
dtype | Data type for the model parameters. If specified, the model is converted to this dtype. |
inference_procedure | The procedure to use for inference: a callable following the InferenceProcedureCallable protocol, a string naming a method on the model class, or the default_infer_procedure function (the default). |
pretrained_parameter_file | Path to a pre-trained model parameter file. If provided, parameters are loaded from this file after initialization. |
compile | Whether to compile the model with torch.compile() for potentially better performance. If has_inference_model is True, both models are compiled. |
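The interplay of has_inference_model, inference_thread_only, device, and dtype can be pictured as choosing between a copied and a shared inference instance. This is a minimal sketch that assumes a deep copy is how the separate instance is made; `build_model_pair` is a hypothetical name, and pretrained loading, compile, and the inference procedure are left out.

```python
from __future__ import annotations

import copy

import torch
from torch import nn


def build_model_pair(
    model: nn.Module,
    has_inference_model: bool = True,
    inference_thread_only: bool = False,
    device: torch.device | str | None = None,
    dtype: torch.dtype | None = None,
) -> tuple[nn.Module, nn.Module | None]:
    """Hypothetical helper: return (training_model, inference_model_or_None)."""
    # Only move or cast when asked; otherwise keep the model as it is.
    if device is not None:
        model = model.to(device)
    if dtype is not None:
        model = model.to(dtype)
    if not has_inference_model:
        return model, None  # inference is not supported in this configuration
    # Shared instance when only inference uses the model; otherwise an
    # independent copy so training and inference never touch the same tensors.
    inference_model = model if inference_thread_only else copy.deepcopy(model)
    return model, inference_model
```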
RAISES | DESCRIPTION |
---|---|
AttributeError | If inference_procedure is a string that does not name an attribute of the model class. |
ValueError | If inference_procedure is a string that does not refer to a callable method on the model class. |
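When inference_procedure is a string, the two documented errors suggest a resolution step along these lines. This is a hedged sketch: `resolve_inference_procedure` is a hypothetical name and the error messages are invented. It also shows why a method name fits the procedure convention: looked up on the class, a method is a plain function whose first parameter is the model instance.

```python
from __future__ import annotations

from typing import Any, Callable


def resolve_inference_procedure(
    model_cls: type, procedure: Callable[..., Any] | str
) -> Callable[..., Any]:
    """Hypothetical helper: turn a method name into a procedure(model, *args) callable."""
    if not isinstance(procedure, str):
        return procedure  # already a callable; use it unchanged
    if not hasattr(model_cls, procedure):
        raise AttributeError(f"{model_cls.__name__} has no attribute {procedure!r}")
    attr = getattr(model_cls, procedure)
    if not callable(attr):
        raise ValueError(f"{model_cls.__name__}.{procedure} is not callable")
    # Accessed on the class, a method is an unbound function taking `self` first,
    # so it can be called as attr(model, *args, **kwds).
    return attr
```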
Source code in src/pamiq_core/torch/model.py
sync_impl
Synchronize training model parameters to the inference model.
This method implements an efficient parameter synchronization strategy by swapping model references and copying state dictionaries. It preserves gradients on the training model during the sync operation.
PARAMETER | DESCRIPTION |
---|---|
inference_model | The inference model to synchronize parameters to. |
Note
The models are put in eval mode during sync and returned to train mode afterwards to ensure proper behavior of layers like BatchNorm and Dropout.
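A minimal sketch of the swap-and-copy idea, assuming plain nn.Module objects instead of the real wrapper classes (`sync_by_swap` is a hypothetical name, and locking is omitted). The inference side receives the trained weights through a cheap reference exchange; the slower state-dict copy then targets the module that inference no longer uses. In this sketch the training module's parameter objects change identity after each sync, so an optimizer built over the old parameters would no longer update the training module.

```python
from __future__ import annotations

import torch
from torch import nn


def sync_by_swap(training: nn.Module, inference: nn.Module) -> tuple[nn.Module, nn.Module]:
    """Hypothetical helper: return (new_training, new_inference) after syncing."""
    training.eval()
    # Hold on to the gradients accumulated on the current training module.
    grads = [p.grad for p in training.parameters()]
    # Swap roles: the just-trained module now serves inference...
    new_training, new_inference = inference, training
    # ...and the former inference module receives a copy of the trained weights.
    new_training.load_state_dict(new_inference.state_dict())
    # Re-attach the held gradients so they survive the sync.
    for param, grad in zip(new_training.parameters(), grads):
        param.grad = grad
    new_training.train()
    return new_training, new_inference
```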
Source code in src/pamiq_core/torch/model.py
forward
Forward pass through the training model.
PARAMETER | DESCRIPTION |
---|---|
*args | Positional arguments to pass to the model. |
**kwds | Keyword arguments to pass to the model. |
RETURNS | DESCRIPTION |
---|---|
Any | The output from the model's forward pass. |
Source code in src/pamiq_core/torch/model.py
save_state
Save the model parameters.
PARAMETER | DESCRIPTION |
---|---|
path | Base path for saving the model state. The actual file will be saved as "{path}.pt". |
Source code in src/pamiq_core/torch/model.py
load_state
Load model parameters.
PARAMETER | DESCRIPTION |
---|---|
path | Base path for loading the model state. The actual file loaded will be "{path}.pt". |
Source code in src/pamiq_core/torch/model.py
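The "{path}.pt" convention used by save_state and load_state can be sketched with torch.save and torch.load on a state dict. `save_model_state` and `load_model_state` are hypothetical helpers, and loading with map_location="cpu" is an assumption of this sketch (load_state_dict then copies the values onto whatever device the target module lives on).

```python
from __future__ import annotations

from pathlib import Path

import torch
from torch import nn


def save_model_state(model: nn.Module, path: str | Path) -> Path:
    """Hypothetical helper: write the state dict to "{path}.pt"."""
    # ".pt" is appended to the base path rather than replacing an existing suffix.
    file = Path(f"{path}.pt")
    torch.save(model.state_dict(), file)
    return file


def load_model_state(model: nn.Module, path: str | Path) -> None:
    """Hypothetical helper: read parameters back from "{path}.pt"."""
    state = torch.load(Path(f"{path}.pt"), map_location="cpu")
    model.load_state_dict(state)
```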
pamiq_core.torch.TorchInferenceModel
Bases: InferenceModel
Thread-safe wrapper for PyTorch models used in inference.
This class provides a thread-safe interface for performing inference with PyTorch models in multi-threaded environments. It uses a lock to ensure that model updates and inference operations don't interfere with each other.
CLASS TYPE PARAMETER | DESCRIPTION |
---|---|
T | The type of the PyTorch model (must be an nn.Module subclass). BOUND: nn.Module |
Initialize the TorchInferenceModel.
PARAMETER | DESCRIPTION |
---|---|
model | A PyTorch model to wrap for thread-safe inference. |
inference_procedure | A callable that defines how to perform inference with the model. It must accept the model as its first argument, followed by any additional inference arguments. |
Source code in src/pamiq_core/torch/model.py
infer
Perform thread-safe inference with gradient computation disabled.
This method executes the inference procedure with the model while ensuring thread safety through locking and disabling gradient computation for efficiency.
PARAMETER | DESCRIPTION |
---|---|
*args | Positional arguments to pass to the inference procedure. |
**kwds | Keyword arguments to pass to the inference procedure. |
RETURNS | DESCRIPTION |
---|---|
Any | The output from the inference procedure. |
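The locking behaviour of infer can be illustrated with the standard library alone. In this sketch, `LockedInference` and `replace_model` are hypothetical names standing in for the real class and its synchronization hook, and the gradient-disabling part is deliberately left out so the example runs without PyTorch.

```python
import threading
from typing import Any, Callable


class LockedInference:
    """Hypothetical stand-in for TorchInferenceModel's locking behaviour."""

    def __init__(self, model: Any, procedure: Callable[..., Any]) -> None:
        self._model = model
        self._procedure = procedure
        self._lock = threading.Lock()

    def replace_model(self, new_model: Any) -> None:
        # Updates take the same lock as inference, so a swap never lands mid-call.
        with self._lock:
            self._model = new_model

    def infer(self, *args: Any, **kwds: Any) -> Any:
        # The real method also disables gradient computation; omitted in this sketch.
        with self._lock:
            return self._procedure(self._model, *args, **kwds)
```

As with the documented procedure convention, the procedure receives the model first and the caller's arguments after it.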
Source code in src/pamiq_core/torch/model.py
unwrap
Get a context manager for direct access to the underlying model.
This method returns a context manager that provides thread-safe direct access to the raw PyTorch model. This is useful when you need to perform operations that are not exposed through the standard inference interface.
PARAMETER | DESCRIPTION |
---|---|
inference_mode | If True (default), the context will have torch.inference_mode enabled, which disables gradient computation for better performance. |
RETURNS | DESCRIPTION |
---|---|
UnwrappedContextManager[T] | A context manager that yields the raw PyTorch model when entered. |
Example
>>> inference_model = TorchInferenceModel(my_model, procedure)
>>> with inference_model.unwrap() as model:
...     # Direct access to the model with inference mode enabled
...     output = model.some_custom_method(input)
...     hidden_state = model.hidden_layer.weight
Note
The context manager ensures thread safety by acquiring a lock for the duration of the context. Avoid holding the context for extended periods to prevent blocking other threads.
Source code in src/pamiq_core/torch/model.py
pamiq_core.torch.TorchTrainer
TorchTrainer(
training_condition_data_user: str | None = None,
min_buffer_size: int = 0,
min_new_data_count: int = 0,
)
Bases: Trainer
Base class for PyTorch model training in pamiq-core.
This trainer integrates PyTorch models with the pamiq-core framework, providing functionality for optimizer configuration, state persistence, and model type validation. It automatically handles the setup and teardown of optimizers and learning rate schedulers during the training process.
Subclasses should implement the create_optimizers and train methods to define the specific training behavior.
Initialize the PyTorch trainer.
Sets up empty containers for optimizers, schedulers, and their respective states. Actual optimizer and scheduler instances will be created during the setup phase.
PARAMETER | DESCRIPTION |
---|---|
training_condition_data_user | Name of the data user to check for trainability. If None, the trainer is always trainable. |
min_buffer_size | Minimum total data points required in the buffer. |
min_new_data_count | Minimum number of new data points required since the last training. |
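Taken together, these three parameters describe a simple gate on when training may run. The sketch below assumes both thresholds must be met (buffer size at least min_buffer_size and new data at least min_new_data_count); `DataUserStats` and `is_trainable` are hypothetical names, and pamiq-core's actual check may differ in detail.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class DataUserStats:
    """Hypothetical snapshot of the condition data user's buffer."""

    buffer_size: int  # total data points currently in the buffer
    new_data_count: int  # data points added since the last training run


def is_trainable(
    stats: DataUserStats | None, min_buffer_size: int = 0, min_new_data_count: int = 0
) -> bool:
    if stats is None:
        return True  # no training_condition_data_user: always trainable
    return stats.buffer_size >= min_buffer_size and stats.new_data_count >= min_new_data_count
```

With both thresholds left at their default of 0, any configured data user passes the check.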
Source code in src/pamiq_core/torch/trainer.py
get_torch_training_model
Get a TorchTrainingModel with type checking.
Retrieves a TorchTrainingModel by name and validates the type of its internal model.
PARAMETER | DESCRIPTION |
---|---|
name | Name of the model to retrieve. |
module_cls | Expected internal module class. |
RETURNS | DESCRIPTION |
---|---|
TorchTrainingModel[T] | An instance of TorchTrainingModel with the specified model type. |
RAISES | DESCRIPTION |
---|---|
TypeError | If the model is not a TorchTrainingModel or its internal model is not an instance of module_cls. |
Source code in src/pamiq_core/torch/trainer.py
create_optimizers abstractmethod
Create and return optimizers and optional schedulers for training.
Implementations should create optimizers for each model being trained and may optionally create learning rate schedulers.
Returns either:
- a dictionary mapping names to optimizers, or
- a tuple containing (optimizers dictionary, schedulers dictionary).
Source code in src/pamiq_core/torch/trainer.py
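The two permitted return shapes can be sketched as follows. This is a hedged example, not pamiq-core code: in a real subclass the body would sit in the create_optimizers method and obtain its model through get_torch_training_model, whereas here it is a free function over a plain module so it runs standalone. `split_optimizers_setup` is a hypothetical helper showing how a caller could treat both shapes uniformly; the name "main", the learning rate, and the StepLR schedule are arbitrary choices.

```python
from __future__ import annotations

from typing import Any

import torch
from torch import nn
from torch.optim.lr_scheduler import StepLR


def create_optimizers(model: nn.Module) -> tuple[dict[str, torch.optim.Optimizer], dict[str, Any]]:
    # One optimizer per trained model, keyed by a name...
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # ...plus an optional scheduler registered under the same name.
    scheduler = StepLR(optimizer, step_size=100, gamma=0.5)
    return {"main": optimizer}, {"main": scheduler}


def split_optimizers_setup(setup: Any) -> tuple[dict[str, Any], dict[str, Any]]:
    """Hypothetical helper: accept either return shape and yield (optimizers, schedulers)."""
    if isinstance(setup, tuple):
        optimizers, schedulers = setup
        return dict(optimizers), dict(schedulers)
    return dict(setup), {}  # dictionary-only form: no schedulers
```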
setup
Set up the training environment.
Initializes optimizers and schedulers by calling the create_optimizers method and restores their states if previously saved.
Source code in src/pamiq_core/torch/trainer.py
teardown
Clean up after training.
Stores the current states of the optimizers and schedulers before cleanup so that the next setup can restore them.
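The setup/teardown hand-off can be sketched with optimizer state dicts. This minimal sketch assumes a single Adam optimizer and a hypothetical `OptimizerStateKeeper` class; it shows that a kept state dict can be loaded into an optimizer rebuilt over a different module with the same parameter shapes, as would be the case after a sync swaps the training module.

```python
from __future__ import annotations

from typing import Any

import torch
from torch import nn


class OptimizerStateKeeper:
    """Hypothetical sketch of keeping optimizer state between training runs."""

    def __init__(self) -> None:
        self.kept_state: dict[str, Any] | None = None
        self.optimizer: torch.optim.Optimizer | None = None

    def setup(self, module: nn.Module) -> torch.optim.Optimizer:
        # Build a fresh optimizer over the module's *current* parameters...
        self.optimizer = torch.optim.Adam(module.parameters(), lr=1e-2)
        # ...then restore step counts and moment estimates if a state was kept.
        if self.kept_state is not None:
            self.optimizer.load_state_dict(self.kept_state)
        return self.optimizer

    def teardown(self) -> None:
        # Keep the state before dropping the optimizer object.
        if self.optimizer is not None:
            self.kept_state = self.optimizer.state_dict()
            self.optimizer = None
```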
save_state
Save trainer state to disk.
Persists the states of all optimizers and schedulers to the specified directory path.
PARAMETER | DESCRIPTION |
---|---|
path | Directory path where the state should be saved. |
Source code in src/pamiq_core/torch/trainer.py
load_state
Load trainer state from disk.
Loads the previously saved states of optimizers and schedulers from the specified directory path.
PARAMETER | DESCRIPTION |
---|---|
path | Directory path from which the state should be loaded. |
RAISES | DESCRIPTION |
---|---|
ValueError | If the path does not exist or is not a directory. |
Source code in src/pamiq_core/torch/trainer.py
pamiq_core.torch.get_device
Retrieves the device on which the module runs.
PARAMETER | DESCRIPTION |
---|---|
module | The module whose device should be determined. |
default_device | The device to return if the module's device cannot be determined. |
Returns: The device that the module uses, or default_device.
Source code in src/pamiq_core/torch/model.py
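A minimal sketch of the lookup, assuming the first parameter (or, failing that, the first buffer) is representative of the whole module. `find_device` is a hypothetical name, and pamiq-core's own handling may differ, for example for modules spread across several devices.

```python
from __future__ import annotations

import torch
from torch import nn


def find_device(module: nn.Module, default_device: torch.device | None = None) -> torch.device | None:
    """Hypothetical helper: device of the first parameter or buffer, else default_device."""
    param = next(module.parameters(), None)
    if param is not None:
        return param.device
    buffer = next(module.buffers(), None)
    if buffer is not None:
        return buffer.device
    return default_device  # e.g. parameter-free modules such as nn.ReLU
```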
pamiq_core.torch.default_infer_procedure
Default inference procedure with device placement.
This function automatically moves tensor arguments to the same device as the model before performing inference. Non-tensor arguments are passed through unchanged.
PARAMETER | DESCRIPTION |
---|---|
model | The model to run inference with. |
*args | Positional arguments to pass to the model. Tensors will be moved to the model's device. |
**kwds | Keyword arguments to pass to the model. Tensor values will be moved to the model's device. |
RETURNS | DESCRIPTION |
---|---|
Any | The output from the model's forward pass. |
Note
When writing a custom inference procedure to use in place of this function, make sure input tensors are moved to the model's device to avoid device-mismatch errors.
Source code in src/pamiq_core/torch/model.py
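Following the note above, a custom procedure that keeps the model-first calling convention and handles device placement might look like this. It is a sketch under stated assumptions: the model is taken to have at least one parameter, only top-level tensor arguments are moved (tensors nested in lists or dicts are not), and `infer_on_model_device` is a hypothetical name.

```python
from __future__ import annotations

from typing import Any

import torch
from torch import nn


def infer_on_model_device(model: nn.Module, *args: Any, **kwds: Any) -> Any:
    """Hypothetical procedure: move top-level tensor inputs to the model's device, then call it."""
    device = next(model.parameters()).device  # assumes the model has parameters

    def to_device(value: Any) -> Any:
        return value.to(device) if isinstance(value, torch.Tensor) else value

    moved_args = tuple(to_device(a) for a in args)
    moved_kwds = {key: to_device(value) for key, value in kwds.items()}
    return model(*moved_args, **moved_kwds)
```

Because it accepts the model first, a function of this shape should be usable wherever a callable inference_procedure is accepted.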
pamiq_core.torch.model.UnwrappedContextManager
Context manager for accessing the raw PyTorch model with thread safety.
This context manager provides direct access to the underlying PyTorch model while ensuring thread safety through locking and, optionally, enabling inference mode.
Initialize the context manager.
PARAMETER | DESCRIPTION |
---|---|
model | The PyTorch model to provide access to. |
lock | The lock to use for thread synchronization. |
inference_mode | If True, torch.inference_mode will be enabled during the context, disabling gradient computation. If False, gradients will be computed normally. |
Source code in src/pamiq_core/torch/model.py
__enter__
Enter the context and return the model.
Acquires the lock and optionally enables inference mode before returning the model for direct access.
RETURNS | DESCRIPTION |
---|---|
T | The PyTorch model for direct manipulation. |
Source code in src/pamiq_core/torch/model.py
__exit__
Exit the context and release resources.
Exits the inference mode context (if enabled) and releases the lock.
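The enter/exit sequence described above can be sketched as a small class. This is a minimal sketch assuming a plain threading.Lock and a torch.inference_mode() context entered and exited by hand; `LockedModelAccess` is a hypothetical name, not pamiq-core's class. The `finally` clause guarantees the lock is released even if the body of the `with` block raises.

```python
from __future__ import annotations

import threading
from typing import Any, Generic, TypeVar

import torch
from torch import nn

T = TypeVar("T", bound=nn.Module)


class LockedModelAccess(Generic[T]):
    """Hypothetical context manager: lock, optional inference mode, raw model."""

    def __init__(self, model: T, lock: threading.Lock, inference_mode: bool = True) -> None:
        self._model = model
        self._lock = lock
        self._inference_mode = inference_mode
        self._mode_ctx: Any = None

    def __enter__(self) -> T:
        self._lock.acquire()
        try:
            if self._inference_mode:
                # Inference mode is thread-local: it affects only the lock-holding thread.
                self._mode_ctx = torch.inference_mode()
                self._mode_ctx.__enter__()
        except BaseException:
            self._lock.release()  # do not leak the lock if entering fails
            raise
        return self._model

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        try:
            if self._mode_ctx is not None:
                self._mode_ctx.__exit__(exc_type, exc, tb)
                self._mode_ctx = None
        finally:
            self._lock.release()  # always released, even if the body raised
```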