Torch🔥

pamiq_core.torch.TorchTrainingModel

TorchTrainingModel(
    model: T,
    has_inference_model: bool = True,
    inference_thread_only: bool = False,
    device: device | str | None = None,
    dtype: dtype | None = None,
    inference_procedure: InferenceProcedureCallable[T] | str = default_infer_procedure,
    pretrained_parameter_file: str | Path | None = None,
    compile: bool = False,
)

Bases: TrainingModel[TorchInferenceModel[T]]

PyTorch model wrapper for parallel training and inference.

This class enables efficient multi-threaded operation where training and inference can run in parallel on separate model instances. It manages model synchronization between threads and provides various initialization options.

Type Parameters

T: The type of the PyTorch model (must be nn.Module subclass).
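
A minimal construction sketch (assuming pamiq_core and torch are installed; the MyNet class is hypothetical and used only for illustration):

import torch
import torch.nn as nn

from pamiq_core.torch import TorchTrainingModel

class MyNet(nn.Module):  # hypothetical model for illustration
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

# Wrap the model; a separate inference copy is created by default.
training_model = TorchTrainingModel(MyNet(), device="cpu", dtype=torch.float32)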

Initialize the TorchTrainingModel.

PARAMETER DESCRIPTION
model

The PyTorch model to wrap for training.

TYPE: T

has_inference_model

Whether to create a separate inference model. If False, inference is not supported.

TYPE: bool DEFAULT: True

inference_thread_only

If True, the same model instance is shared between training and inference (no copying). Use this when the model is only used for inference.

TYPE: bool DEFAULT: False

device

Device to place the model on. If None, keeps the model on its current device.

TYPE: device | str | None DEFAULT: None

dtype

Data type for the model parameters. If specified, converts the model to this dtype.

TYPE: dtype | None DEFAULT: None

inference_procedure

The procedure to use for inference. Can be:

- A callable following the InferenceProcedureCallable protocol
- A string naming a method on the model class
- The default_infer_procedure function (default)

TYPE: InferenceProcedureCallable[T] | str DEFAULT: default_infer_procedure

pretrained_parameter_file

Path to a pre-trained model parameter file. If provided, loads parameters from this file after initialization.

TYPE: str | Path | None DEFAULT: None

compile

Whether to compile the model using torch.compile() for potentially better performance. If has_inference_model is True, both models are compiled.

TYPE: bool DEFAULT: False

RAISES DESCRIPTION
AttributeError

If inference_procedure is a string but no attribute of that name exists on the model class.

ValueError

If inference_procedure is a string but doesn't refer to a callable method on the model class.

Source code in src/pamiq_core/torch/model.py
@override
def __init__(
    self,
    model: T,
    has_inference_model: bool = True,
    inference_thread_only: bool = False,
    device: torch.device | str | None = None,
    dtype: torch.dtype | None = None,
    inference_procedure: InferenceProcedureCallable[T]
    | str = default_infer_procedure,
    pretrained_parameter_file: str | Path | None = None,
    compile: bool = False,
):
    """Initialize the TorchTrainingModel.

    Args:
        model: The PyTorch model to wrap for training.
        has_inference_model: Whether to create a separate inference model.
            If False, inference is not supported.
        inference_thread_only: If True, the same model instance is shared
            between training and inference (no copying). Use this when
            the model is only used for inference.
        device: Device to place the model on. If None, keeps the model
            on its current device.
        dtype: Data type for the model parameters. If specified, converts
            the model to this dtype.
        inference_procedure: The procedure to use for inference. Can be:
            - A callable following the InferenceProcedureCallable protocol
            - A string naming a method on the model class
            - The default_infer_procedure function (default)
        pretrained_parameter_file: Path to a pre-trained model parameter
            file. If provided, loads parameters from this file after
            initialization.
        compile: Whether to compile the model using torch.compile() for
            potentially better performance. If has_inference_model is True,
            both models are compiled.

    Raises:
        AttributeError: If inference_procedure is a string but no attribute of that name exists on the model class.
        ValueError: If inference_procedure is a string but doesn't refer
            to a callable method on the model class.
    """
    super().__init__(has_inference_model, inference_thread_only)
    if dtype is not None:
        model = model.type(dtype)
    self.model = model
    if device is None:  # prevents moving the model to cpu unintentionally.
        device = get_device(model)

    if isinstance(inference_procedure, str):
        method_name = inference_procedure
        if not hasattr(model.__class__, method_name):
            raise AttributeError(
                f"The model class {model.__class__.__name__} does not have "
                f"a method named '{method_name}'"
            )
        inference_procedure = getattr(model.__class__, method_name)
        if not callable(inference_procedure):
            raise ValueError(
                f"The specified inference_procedure '{method_name}' "
                f"is not a callable method on {model.__class__.__name__}"
            )

    self._inference_procedure = inference_procedure
    self.model.to(device)

    if pretrained_parameter_file is not None:
        self.model.load_state_dict(
            torch.load(pretrained_parameter_file, map_location=device)  # pyright: ignore[reportUnknownMemberType]
        )

    if compile:
        if self.has_inference_model:
            # copy before compile
            self.inference_model._raw_model.compile()  # pyright: ignore[reportPrivateUsage, reportUnknownMemberType]
        self.model.compile()  # pyright: ignore[reportUnknownMemberType]
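
As a sketch of the string form of inference_procedure (the Encoder class and its encode method are hypothetical), naming a method resolves it on the model class and uses it in place of forward() during inference:

import torch
import torch.nn as nn

from pamiq_core.torch import TorchTrainingModel

class Encoder(nn.Module):  # hypothetical model for illustration
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(8, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # Custom inference entry point, distinct from forward().
        return torch.tanh(self.fc(x))

# "encode" is looked up on the Encoder class and used for inference.
training_model = TorchTrainingModel(Encoder(), inference_procedure="encode")
out = training_model.inference_model.infer(torch.randn(1, 8))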

sync_impl

sync_impl(inference_model: TorchInferenceModel[T]) -> None

Synchronize training model parameters to the inference model.

This method implements an efficient parameter synchronization strategy by swapping model references and copying state dictionaries. It preserves gradients on the training model during the sync operation.

PARAMETER DESCRIPTION
inference_model

The inference model to synchronize parameters to.

TYPE: TorchInferenceModel[T]

Note

The models are put in eval mode during sync and returned to train mode afterwards to ensure proper behavior of layers like BatchNorm and Dropout.

Source code in src/pamiq_core/torch/model.py
@override
def sync_impl(self, inference_model: TorchInferenceModel[T]) -> None:
    """Synchronize training model parameters to the inference model.

    This method implements an efficient parameter synchronization strategy
    by swapping model references and copying state dictionaries. It preserves
    gradients on the training model during the sync operation.

    Args:
        inference_model: The inference model to synchronize parameters to.

    Note:
        The models are put in eval mode during sync and returned to train
        mode afterwards to ensure proper behavior of layers like BatchNorm
        and Dropout.
    """

    self.model.eval()

    # Hold the grads.
    grads: list[torch.Tensor | None] = []
    for p in self.model.parameters():
        grads.append(p.grad)
        p.grad = None

    # Swap the training model and the inference model.
    self.model, inference_model._raw_model = (  # pyright: ignore[reportPrivateUsage]
        inference_model._raw_model,  # pyright: ignore[reportPrivateUsage]
        self.model,
    )
    self.model.load_state_dict(
        self.inference_model._raw_model.state_dict()  # pyright: ignore[reportPrivateUsage]
    )

    # Assign the model grads.
    for i, p in enumerate(self.model.parameters()):
        p.grad = grads[i]

    self.model.train()
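
The swap-and-copy strategy can be illustrated in isolation (a minimal standalone sketch with two plain nn.Linear modules standing in for the training and inference models):

import torch.nn as nn

train_net = nn.Linear(2, 2)
infer_net = nn.Linear(2, 2)

# Stash gradients so they stay with the training side across the swap.
grads = [p.grad for p in train_net.parameters()]
for p in train_net.parameters():
    p.grad = None

# Swap references: the old training net (latest weights) becomes the
# inference net, then the new training net copies those weights back.
train_net, infer_net = infer_net, train_net
train_net.load_state_dict(infer_net.state_dict())

# Restore the stashed gradients onto the training parameters.
for p, g in zip(train_net.parameters(), grads):
    p.grad = g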

forward

forward(*args: Any, **kwds: Any) -> Any

Forward pass through the training model.

PARAMETER DESCRIPTION
*args

Positional arguments to pass to the model.

TYPE: Any DEFAULT: ()

**kwds

Keyword arguments to pass to the model.

TYPE: Any DEFAULT: {}

RETURNS DESCRIPTION
Any

The output from the model's forward pass.

Source code in src/pamiq_core/torch/model.py
@override
def forward(self, *args: Any, **kwds: Any) -> Any:
    """Forward pass through the training model.

    Args:
        *args: Positional arguments to pass to the model.
        **kwds: Keyword arguments to pass to the model.

    Returns:
        The output from the model's forward pass.
    """
    return self.model(*args, **kwds)
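
A minimal usage sketch; gradients from the loss accumulate on the wrapped training model only:

import torch
import torch.nn as nn

from pamiq_core.torch import TorchTrainingModel

training_model = TorchTrainingModel(nn.Linear(4, 2))
out = training_model.forward(torch.randn(5, 4))  # same as training_model.model(...)
loss = out.square().mean()
loss.backward()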

save_state

save_state(path: Path) -> None

Save the model parameters.

PARAMETER DESCRIPTION
path

Base path for saving the model state. The actual file will be saved as "{path}.pt".

TYPE: Path

Source code in src/pamiq_core/torch/model.py
@override
def save_state(self, path: Path) -> None:
    """Save the model parameters.

    Args:
        path: Base path for saving the model state. The actual file
            will be saved as "{path}.pt".
    """
    torch.save(self.model.state_dict(), f"{path}.pt")  # pyright: ignore[reportUnknownMemberType]

load_state

load_state(path: Path) -> None

Load model parameters.

PARAMETER DESCRIPTION
path

Base path for loading the model state. The actual file loaded will be "{path}.pt".

TYPE: Path

Source code in src/pamiq_core/torch/model.py
@override
def load_state(self, path: Path) -> None:
    """Load model parameters.

    Args:
        path: Base path for loading the model state. The actual file
            loaded will be "{path}.pt".
    """
    self.model.load_state_dict(torch.load(f"{path}.pt"))  # pyright: ignore[reportUnknownMemberType]
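
A minimal save/load round trip (note that both methods append the ".pt" suffix themselves):

import tempfile
from pathlib import Path

import torch.nn as nn

from pamiq_core.torch import TorchTrainingModel

training_model = TorchTrainingModel(nn.Linear(4, 2))
with tempfile.TemporaryDirectory() as tmp:
    base = Path(tmp) / "model"
    training_model.save_state(base)  # writes model.pt
    training_model.load_state(base)  # reads model.pt back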

pamiq_core.torch.TorchInferenceModel

TorchInferenceModel(model: T, inference_procedure: InferenceProcedureCallable[T])

Bases: InferenceModel

Thread-safe wrapper for PyTorch models used in inference.

This class provides a thread-safe interface for performing inference with PyTorch models in multi-threaded environments. It uses a lock to ensure that model updates and inference operations don't interfere with each other.

Type Parameters

T: The type of the PyTorch model (must be nn.Module subclass).
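
Direct construction is possible, though in practice TorchTrainingModel usually creates this wrapper for you (a minimal sketch using the default inference procedure):

import torch
import torch.nn as nn

from pamiq_core.torch import TorchInferenceModel, default_infer_procedure

inference_model = TorchInferenceModel(nn.Linear(4, 2), default_infer_procedure)
out = inference_model.infer(torch.randn(1, 4))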

Initialize the TorchInferenceModel.

PARAMETER DESCRIPTION
model

A PyTorch model to wrap for thread-safe inference.

TYPE: T

inference_procedure

A callable that defines how to perform inference with the model. It should accept the model as its first argument, followed by the remaining inference arguments.

TYPE: InferenceProcedureCallable[T]

Source code in src/pamiq_core/torch/model.py
def __init__(
    self, model: T, inference_procedure: InferenceProcedureCallable[T]
) -> None:
    """Initialize the TorchInferenceModel.

    Args:
        model: A PyTorch model to wrap for thread-safe inference.
        inference_procedure: A callable that defines how to perform
            inference with the model. It should accept the model as its
            first argument, followed by the remaining inference arguments.
    """
    self._model = model
    self._inference_procedure = inference_procedure
    self._lock = RLock()

infer

infer(*args: Any, **kwds: Any) -> Any

Perform thread-safe inference with gradient computation disabled.

This method executes the inference procedure with the model while ensuring thread safety through locking and disabling gradient computation for efficiency.

PARAMETER DESCRIPTION
*args

Positional arguments to pass to the inference procedure.

TYPE: Any DEFAULT: ()

**kwds

Keyword arguments to pass to the inference procedure.

TYPE: Any DEFAULT: {}

RETURNS DESCRIPTION
Any

The output from the inference procedure.

Source code in src/pamiq_core/torch/model.py
@torch.inference_mode()
@override
def infer(self, *args: Any, **kwds: Any) -> Any:
    """Perform thread-safe inference with gradient computation disabled.

    This method executes the inference procedure with the model while
    ensuring thread safety through locking and disabling gradient
    computation for efficiency.

    Args:
        *args: Positional arguments to pass to the inference procedure.
        **kwds: Keyword arguments to pass to the inference procedure.

    Returns:
        The output from the inference procedure.
    """
    with self._lock:
        return self._inference_procedure(self._model, *args, **kwds)
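
A sketch of concurrent use; each call takes the internal lock, so calls from multiple threads serialize safely:

import threading

import torch
import torch.nn as nn

from pamiq_core.torch import TorchInferenceModel, default_infer_procedure

inference_model = TorchInferenceModel(nn.Linear(4, 2), default_infer_procedure)

def worker() -> None:
    for _ in range(100):
        inference_model.infer(torch.randn(1, 4))

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()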

pamiq_core.torch.TorchTrainer

TorchTrainer(
    training_condition_data_user: str | None = None,
    min_buffer_size: int = 0,
    min_new_data_count: int = 0,
)

Bases: Trainer

Base class for PyTorch model training in pamiq-core.

This trainer integrates PyTorch models with the pamiq-core framework, providing functionality for optimizer configuration, state persistence, and model type validation. It automatically handles the setup and teardown of optimizers and learning rate schedulers during the training process.

Subclasses should implement the create_optimizers and train methods to define the specific training behavior.
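
A minimal subclass sketch (the model name "policy", the training logic, and the assumption that train() takes no arguments are all illustrative):

import torch
import torch.nn as nn

from pamiq_core.torch import TorchTrainer

class MyTrainer(TorchTrainer):  # hypothetical trainer for illustration
    def create_optimizers(self):
        policy = self.get_torch_training_model("policy")
        return {"policy": torch.optim.Adam(policy.model.parameters(), lr=1e-3)}

    def train(self) -> None:
        policy = self.get_torch_training_model("policy")
        optim = self.optimizers["policy"]  # populated during setup()
        loss = policy.forward(torch.randn(8, 4)).square().mean()
        optim.zero_grad()
        loss.backward()
        optim.step()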

Initialize the PyTorch trainer.

Sets up empty containers for optimizers, schedulers, and their respective states. Actual optimizer and scheduler instances will be created during the setup phase.

PARAMETER DESCRIPTION
training_condition_data_user

Name of the data user to check for trainability. If None, the trainer is always trainable.

TYPE: str | None DEFAULT: None

min_buffer_size

Minimum total data points required in the buffer.

TYPE: int DEFAULT: 0

min_new_data_count

Minimum number of new data points required since last training.

TYPE: int DEFAULT: 0

Source code in src/pamiq_core/torch/trainer.py
def __init__(
    self,
    training_condition_data_user: str | None = None,
    min_buffer_size: int = 0,
    min_new_data_count: int = 0,
) -> None:
    """Initialize the PyTorch trainer.

    Sets up empty containers for optimizers, schedulers, and their
    respective states. Actual optimizer and scheduler instances will
    be created during the setup phase.

    Args:
        training_condition_data_user: Name of the data user to check for trainability.
            If None, the trainer is always trainable.
        min_buffer_size: Minimum total data points required in the buffer.
        min_new_data_count: Minimum number of new data points required since last training.
    """
    super().__init__(
        training_condition_data_user,
        min_buffer_size,
        min_new_data_count,
    )

    # Containers for optimizer and scheduler instances
    self.optimizers: OptimizersDict = {}
    self.lr_schedulers: LRSchedulersDict = {}

    # Containers for persistent optimizer and scheduler states
    self.optimizer_states: dict[str, StateDict] = {}
    self.lr_scheduler_states: dict[str, StateDict] = {}

get_torch_training_model

get_torch_training_model(
    name: str, module_cls: type[T] = nn.Module
) -> TorchTrainingModel[T]

Get a TorchTrainingModel with type checking. Retrieves a PyTorch training model by name and validates the type of its internal model.

PARAMETER DESCRIPTION
name

Name of the model to retrieve.

TYPE: str

module_cls

Expected internal module class.

TYPE: type[T] DEFAULT: Module

RETURNS DESCRIPTION
TorchTrainingModel[T]

An instance of TorchTrainingModel with the specified model type.

RAISES DESCRIPTION
TypeError

If the model is not a TorchTrainingModel or its internal model is not an instance of the specified module class.

Source code in src/pamiq_core/torch/trainer.py
def get_torch_training_model[T: nn.Module](
    self, name: str, module_cls: type[T] = nn.Module
) -> TorchTrainingModel[T]:
    """Get a TorchTrainingModel with type checking. Retrieves a PyTorch
    model training model by name and validates internal model type.

    Args:
        name: Name of the model to retrieve.
        module_cls: Expected internal module class.

    Returns:
        An instance of TorchTrainingModel with the specified model type.

    Raises:
        TypeError: If the model is not a TorchTrainingModel or its internal
            model is not an instance of the specified module class.
    """
    training_model = self.get_training_model(name)
    if not isinstance(training_model, TorchTrainingModel):
        raise TypeError(f"Model {name} is not a instance of TorchTrainingModel")

    training_model = cast(TorchTrainingModel[T], training_model)

    if not isinstance(training_model.model, module_cls):
        raise TypeError(
            f"Internal model is not a instance of {module_cls.__name__}"
        )
    return training_model
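
Passing module_cls narrows the wrapped model's static type and validates it at runtime (MyEncoder and the "encoder" name are hypothetical):

import torch
import torch.nn as nn

from pamiq_core.torch import TorchTrainer

class MyEncoder(nn.Module):  # hypothetical model class
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)

class EncoderTrainer(TorchTrainer):  # hypothetical trainer for illustration
    def create_optimizers(self):
        # Raises TypeError unless "encoder" is a TorchTrainingModel
        # whose wrapped model is a MyEncoder instance.
        encoder = self.get_torch_training_model("encoder", MyEncoder)
        return {"encoder": torch.optim.Adam(encoder.model.parameters())}

    def train(self) -> None: ...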

create_optimizers abstractmethod

create_optimizers() -> OptimizersSetup

Create and return optimizers and optional schedulers for training.

Implementations should create optimizers for each model being trained, and optionally create learning rate schedulers.

RETURNS DESCRIPTION
OptimizersSetup

Either:

- Dictionary mapping names to optimizers
- Tuple containing (optimizers dictionary, schedulers dictionary)

Source code in src/pamiq_core/torch/trainer.py
@abstractmethod
def create_optimizers(self) -> OptimizersSetup:
    """Create and return optimizers and optional schedulers for training.
    Implementations should create optimizers for each model being trained,
    and optionally create learning rate schedulers.
    Returns:
        Either:
        - Dictionary mapping names to optimizers
        - Tuple containing (optimizers dictionary, schedulers dictionary)
    """
    pass
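
A sketch of the tuple form, returning both optimizers and learning rate schedulers (names and hyperparameters are illustrative):

import torch
import torch.nn as nn

from pamiq_core.torch import TorchTrainer

class ScheduledTrainer(TorchTrainer):  # hypothetical trainer for illustration
    def create_optimizers(self):
        model = self.get_torch_training_model("policy").model
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        sched = torch.optim.lr_scheduler.StepLR(optim, step_size=10)
        # Tuple form: (optimizers dict, schedulers dict).
        return {"policy": optim}, {"policy": sched}

    def train(self) -> None: ...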

setup

setup() -> None

Set up the training environment.

Initializes optimizers and schedulers by calling the create_optimizers method and restores their states if previously saved.

Source code in src/pamiq_core/torch/trainer.py
@override
def setup(self) -> None:
    """Set up the training environment.

    Initializes optimizers and schedulers by calling the `create_optimizers`
    method and restores their states if previously saved.
    """
    super().setup()
    self._setup_optimizers_and_schedulers()

teardown

teardown() -> None

Clean up after training.

Keeps the current state of optimizers and schedulers before cleanup.

Source code in src/pamiq_core/torch/trainer.py
@override
def teardown(self) -> None:
    """Clean up after training.

    Keeps the current state of optimizers and schedulers before
    cleanup.
    """
    super().teardown()
    self._keep_optimizer_and_scheduler_states()

save_state

save_state(path: Path) -> None

Save trainer state to disk.

Persists the states of all optimizers and schedulers to the specified directory path.

PARAMETER DESCRIPTION
path

Directory path where state should be saved

TYPE: Path

Source code in src/pamiq_core/torch/trainer.py
@override
def save_state(self, path: Path) -> None:
    """Save trainer state to disk.

    Persists the states of all optimizers and schedulers to the specified
    directory path.

    Args:
        path: Directory path where state should be saved
    """
    super().save_state(path)

    # By the time this method runs, the optimizer and LR scheduler states have already been captured by the `teardown` method.

    # Save optimizer states to disk
    for name, optimizer_state in self.optimizer_states.items():
        torch.save(optimizer_state, path / f"{name}.optim.pt")  # pyright: ignore[reportUnknownMemberType]

    # Save scheduler states to disk
    for name, scheduler_state in self.lr_scheduler_states.items():
        torch.save(scheduler_state, path / f"{name}.lrsch.pt")  # pyright: ignore[reportUnknownMemberType]

load_state

load_state(path: Path) -> None

Load trainer state from disk.

Loads the previously saved states of optimizers and schedulers from the specified directory path.

PARAMETER DESCRIPTION
path

Directory path from where state should be loaded

TYPE: Path

RAISES DESCRIPTION
ValueError

If the path does not exist or is not a directory

Source code in src/pamiq_core/torch/trainer.py
@override
def load_state(self, path: Path) -> None:
    """Load trainer state from disk.

    Loads the previously saved states of optimizers and schedulers from
    the specified directory path.

    Args:
        path: Directory path from where state should be loaded

    Raises:
        ValueError: If the path does not exist or is not a directory
    """
    if not path.is_dir():
        raise ValueError(f"Path {path} is not a directory or does not exist")

    super().load_state(path)

    # Load optimizer states from disk
    for optimizer_path in path.glob("*.optim.pt"):
        name = optimizer_path.name.replace(".optim.pt", "")
        self.optimizer_states[name] = torch.load(optimizer_path)  # pyright: ignore[reportUnknownMemberType]

    # Load scheduler states from disk
    for scheduler_path in path.glob("*.lrsch.pt"):
        name = scheduler_path.name.replace(".lrsch.pt", "")
        self.lr_scheduler_states[name] = torch.load(scheduler_path)  # pyright: ignore[reportUnknownMemberType]
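
A sketch of the resulting on-disk layout and the save/load round trip (the directory path and the optimizer name "policy" are illustrative; `trainer` is assumed to be an already-constructed TorchTrainer subclass instance whose teardown() has captured its states):

from pathlib import Path

# `trainer` is assumed to be a TorchTrainer subclass instance.
state_dir = Path("checkpoints/trainer")
state_dir.mkdir(parents=True, exist_ok=True)
trainer.save_state(state_dir)  # writes policy.optim.pt and policy.lrsch.pt
trainer.load_state(state_dir)  # restores states from the same files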

pamiq_core.torch.get_device

get_device(module: Module, default_device: device | None = None) -> torch.device

Retrieves the device where the module runs.

PARAMETER DESCRIPTION
module

The module whose device you want to determine.

TYPE: Module

default_device

A device to return if no device is found on the module.

TYPE: device | None DEFAULT: None

RETURNS DESCRIPTION
torch.device

A device that the module uses, or default_device.

Source code in src/pamiq_core/torch/model.py
def get_device(
    module: nn.Module, default_device: torch.device | None = None
) -> torch.device:
    """Retrieves the device where the module runs.

    Args:
        module: The module whose device you want to determine.
        default_device: A device to return if no device is found on the module.

    Returns:
        A device that the module uses, or default_device.
    """
    for param in module.parameters():
        return param.device
    for buf in module.buffers():
        return buf.device
    if default_device is None:
        default_device = torch.get_default_device()
    return default_device
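
A minimal usage sketch, including the fallback for a module with no parameters or buffers:

import torch
import torch.nn as nn

from pamiq_core.torch import get_device

print(get_device(nn.Linear(2, 2)))                 # device of the weights
print(get_device(nn.ReLU(), torch.device("cpu")))  # falls back to default_device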

pamiq_core.torch.default_infer_procedure

default_infer_procedure(model: Module, *args: Any, **kwds: Any) -> Any

Default inference procedure with device placement.

This function automatically moves tensor arguments to the same device as the model before performing inference. Non-tensor arguments are passed through unchanged.

PARAMETER DESCRIPTION
model

The model to run inference with.

TYPE: Module

*args

Positional arguments to pass to the model. Tensors will be moved to the model's device.

TYPE: Any DEFAULT: ()

**kwds

Keyword arguments to pass to the model. Tensor values will be moved to the model's device.

TYPE: Any DEFAULT: {}

RETURNS DESCRIPTION
Any

The output from the model's forward pass.

Note

When supplying a custom inference procedure, ensure that input tensors are properly moved to the correct device to avoid device mismatch errors.

Source code in src/pamiq_core/torch/model.py
def default_infer_procedure(model: nn.Module, *args: Any, **kwds: Any) -> Any:
    """Default inference procedure with device placement.

    This function automatically moves tensor arguments to the same device
    as the model before performing inference. Non-tensor arguments are
    passed through unchanged.

    Args:
        model: The model to run inference with.
        *args: Positional arguments to pass to the model. Tensors will be
            moved to the model's device.
        **kwds: Keyword arguments to pass to the model. Tensor values will
            be moved to the model's device.

    Returns:
        The output from the model's forward pass.

    Note:
        When supplying a custom inference procedure, ensure that input
        tensors are properly moved to the correct device to avoid device
        mismatch errors.
    """
    device = get_device(model)
    new_args: list[Any] = []
    new_kwds: dict[str, Any] = {}
    for i in args:
        if isinstance(i, torch.Tensor):
            i = i.to(device)
        new_args.append(i)

    for k, v in kwds.items():
        if isinstance(v, torch.Tensor):
            v = v.to(device)
        new_kwds[k] = v

    return model(*new_args, **new_kwds)
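
A custom procedure can delegate to the default one, for example to reshape inputs before the forward pass (a minimal sketch; the flattening step is illustrative):

import torch
import torch.nn as nn

from pamiq_core.torch import TorchTrainingModel, default_infer_procedure

def flatten_then_infer(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Device placement is delegated to the default procedure.
    return default_infer_procedure(model, x.flatten(start_dim=1))

training_model = TorchTrainingModel(nn.Linear(4, 2), inference_procedure=flatten_then_infer)
out = training_model.inference_model.infer(torch.randn(3, 2, 2))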