pytorch_wrapper

Convert a trained PyTorch model into an evaluator compatible with URANIE Relauncher.

uranie_evaluator(model_torch, outputkey=[])

Parameters:

  • model_torch (torch.nn.Module): A trained PyTorch model. The wrapper considers the model trained if at least one of its parameters has a non-None gradient.

  • outputkey (list): Optional list of keys or indices that selects specific outputs of a multi-output model. Use integers for tuple or list outputs (e.g., [0, 1]) and strings for dict outputs (e.g., ["mean", "std"]). Default: [] (all outputs are returned).
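The outputkey rules can be illustrated with a small pure-Python sketch. The helper select_outputs is hypothetical and only mirrors the documented behavior (integers for tuple/list outputs, strings for dict outputs, [] keeps everything, an invalid key gives None); it is not the function used inside uratools.pytorch_wrapper.

```python
# Hypothetical helper that mirrors the documented outputkey rules; it is not
# the implementation used by uratools.pytorch_wrapper.
def select_outputs(outputs, outputkey):
    """Return the selected entries of `outputs`, or None if `outputkey` is invalid."""
    if not outputkey:
        return outputs  # default []: keep every output
    if isinstance(outputs, (tuple, list)):
        # Tuple/list outputs are indexed by integers.
        if not all(isinstance(k, int) and 0 <= k < len(outputs) for k in outputkey):
            return None
        return [outputs[k] for k in outputkey]
    if isinstance(outputs, dict):
        # Dict outputs are indexed by string keys.
        if not all(isinstance(k, str) and k in outputs for k in outputkey):
            return None
        return [outputs[k] for k in outputkey]
    return None  # a single tensor or scalar has nothing to select from


print(select_outputs((10.0, [1.0, 2.0]), [0]))             # [10.0]
print(select_outputs({"mean": 0.5, "std": 0.1}, ["std"]))  # [0.1]
print(select_outputs({"mean": 0.5}, [0]))                  # None
```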

Returns:

  • callable or None: A Python callable that accepts positional float arguments (*args) and returns a 1-D list of floats. None is returned instead if the model is not trained or if model_torch or outputkey is invalid.

Exceptions/Errors:

  • Returns None and prints an error message if model_torch is not a valid torch.nn.Module.

  • Returns None and prints an error message if the model has not been trained (no parameter has a gradient).

  • Returns None and prints an error message if outputkey is invalid.

  • Prints a warning if multi-dimensional outputs are flattened.
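Since uranie_evaluator reports these errors by printing a message and returning None rather than raising, a driver script may want to stop early instead of continuing with a None evaluator. A minimal sketch, assuming a hypothetical helper named require_evaluator and using a plain lambda as a stand-in for a successful return value:

```python
def require_evaluator(wrapped, what="model"):
    """Raise instead of silently continuing when the wrapper returned None."""
    if wrapped is None:
        raise RuntimeError(
            f"could not build a URANIE evaluator for {what}; see the message printed above"
        )
    return wrapped


# Stand-ins for the two possible return values of uranie_evaluator.
evaluator = require_evaluator(lambda *args: [sum(args)])
row = [0.1, 0.2, 0.3, 0.4]
print(evaluator(*row))  # a sample stored as a list is passed with argument unpacking

try:
    require_evaluator(None, what="Net")
except RuntimeError as err:
    print(err)
```

The same unpacking pattern, evaluator(*row), applies to the callable returned by uranie_evaluator.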

Examples

One-dimensional example:

import torch
import numpy as np

from uratools.pytorch_wrapper import uranie_evaluator

# Define model
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 8)
        self.fc2 = torch.nn.Linear(8, 1)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Create model and optimizer
model = Net()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()

# Generate dummy training data
X_train = torch.randn(50, 4)
y_train = torch.randn(50, 1)

# Training loop
model.train()
for epoch in range(10):
    optimizer.zero_grad()
    predictions = model(X_train)
    loss = criterion(predictions, y_train)
    loss.backward()
    optimizer.step()

# Wrap for URANIE
wrapped = uranie_evaluator(model)
if wrapped:
    result = wrapped(0.1, 0.2, 0.3, 0.4)
    print(result)  # [predicted_value]
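The training check relies only on parameter gradients, so some common situations can make a properly trained model look untrained. The sketch below restates the documented check as a hypothetical looks_trained function and shows that calling optimizer.zero_grad(set_to_none=True) after training, or restoring weights with load_state_dict(), leaves every .grad set to None. The final zero-loss backward pass is an assumed workaround, not a documented feature: it fills the gradients with zeros while leaving the weights unchanged.

```python
import torch


def looks_trained(model):
    """Hypothetical restatement of the check: at least one parameter has a non-None .grad."""
    return any(p.grad is not None for p in model.parameters())


torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# One ordinary training step populates .grad.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
print(looks_trained(model))  # True

# Clearing gradients after training makes the check fail.
optimizer.zero_grad(set_to_none=True)
print(looks_trained(model))  # False

# A model restored from a state_dict carries weights but no gradients.
restored = torch.nn.Linear(4, 1)
restored.load_state_dict(model.state_dict())
print(looks_trained(restored))  # False

# Assumed workaround: backward on a zero-valued loss sets .grad to zero tensors
# (non-None) without an optimizer step, so the weights are not modified.
(restored(torch.zeros(1, 4)).sum() * 0).backward()
print(looks_trained(restored))  # True
```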

Multi-output example:

import torch
from uratools.pytorch_wrapper import uranie_evaluator

class MultiOut(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 16)
        self.out1 = torch.nn.Linear(16, 2)
        self.out2 = torch.nn.Linear(16, 1)
    
    def forward(self, x):
        h = torch.relu(self.fc(x))
        return (self.out1(h), self.out2(h))

# Create model and train
model = MultiOut()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Generate dummy training data
X_train = torch.randn(50, 4)
y1_train = torch.randn(50, 2)
y2_train = torch.randn(50, 1)

# Training loop
model.train()
for epoch in range(10):
    optimizer.zero_grad()
    out1, out2 = model(X_train)
    loss1 = torch.nn.functional.mse_loss(out1, y1_train)
    loss2 = torch.nn.functional.mse_loss(out2, y2_train)
    loss = loss1 + loss2
    loss.backward()
    optimizer.step()

# Select only first output
wrapped = uranie_evaluator(model, outputkey=[0])
if wrapped:
    result = wrapped(0.1, 0.2, 0.3, 0.4)
    print(result)  # [predicted_value_1, predicted_value_2]
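To check what the wrapped callable should return, the model can also be evaluated directly on the same single-row input. This sketch redeclares MultiOut (left untrained, since only the shapes matter here). The order used for outputkey=[0, 1] is an assumption, because the concatenation order of several selected outputs is not specified. Because out1 has shape (1, 2), the wrapper may also print the flattening warning listed under Exceptions/Errors.

```python
import torch


class MultiOut(torch.nn.Module):
    """Same architecture as the example above; untrained because only shapes matter here."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 16)
        self.out1 = torch.nn.Linear(16, 2)
        self.out2 = torch.nn.Linear(16, 1)

    def forward(self, x):
        h = torch.relu(self.fc(x))
        return (self.out1(h), self.out2(h))


model = MultiOut().eval()
x = torch.tensor([[0.1, 0.2, 0.3, 0.4]], dtype=torch.float32)  # one row, shape (1, 4)
with torch.no_grad():
    out1, out2 = model(x)

expected_first = out1.flatten().tolist()                  # expected for outputkey=[0]
expected_both = expected_first + out2.flatten().tolist()  # assumed order for outputkey=[0, 1]
print(out1.shape, len(expected_first), len(expected_both))  # torch.Size([1, 2]) 2 3
```

Running the same direct evaluation on the trained model and comparing it with wrapped(0.1, 0.2, 0.3, 0.4) is a quick sanity check on outputkey.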

Implementation Details

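  • Input conversion: The positional arguments are packed into a float32 tensor of shape (1, n_args), a single row, before being passed to the model.

  • Output handling: Outputs are converted to a Python list of floats. Multi-dimensional outputs are flattened.

  • Training check: The model is considered trained if at least one parameter has a non-None gradient.

  • Gradient tracking: The model is called inside a torch.no_grad() context, so no autograd graph is built. torch.no_grad() does not switch the model to evaluation mode, so call model.eval() before wrapping a model that contains dropout or batch-normalization layers.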
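torch.no_grad() only disables gradient tracking; it does not switch layers such as torch.nn.Dropout or torch.nn.BatchNorm1d to evaluation mode. The sketch below uses a hypothetical evaluate function that follows the call path described above (arguments packed into a (1, n) float32 tensor, forward pass under torch.no_grad(), flattened list of floats) to show that repeated calls on a model left in training mode can disagree, while calling model.eval() first makes them deterministic.

```python
import torch


def evaluate(model, *args):
    """Hypothetical stand-in for the wrapper's call path."""
    x = torch.tensor([args], dtype=torch.float32)  # shape (1, len(args))
    with torch.no_grad():
        y = model(x)
    return y.flatten().tolist()


torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(4, 64),
    torch.nn.Dropout(p=0.5),
    torch.nn.Linear(64, 1),
)

# Training mode: dropout is still active under no_grad(), so results vary.
model.train()
train_calls = [evaluate(model, 0.1, 0.2, 0.3, 0.4) for _ in range(5)]

# Evaluation mode: dropout is disabled, so every call returns the same value.
model.eval()
eval_calls = [evaluate(model, 0.1, 0.2, 0.3, 0.4) for _ in range(5)]

print(len(set(map(tuple, train_calls))))
print(len(set(map(tuple, eval_calls))))  # 1
```

Neither example on this page contains dropout or batch normalization, so leaving their models in training mode does not change the results, but calling model.eval() before uranie_evaluator is a cheap safeguard.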

Supported Model Types

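Works with any trained torch.nn.Module subclass whose forward() accepts a single float32 tensor of shape (1, n_inputs), which is what the wrapper passes in. Modules that expect another input layout need a custom forward() that reshapes this row first:

  • torch.nn.Linear and torch.nn.Sequential (torch.nn.ModuleList defines no forward(), so it must be held inside a module that does)

  • torch.nn.Conv1d (with in_channels=1, the row is read as an unbatched single-channel signal); torch.nn.Conv2d and torch.nn.Conv3d need the row reshaped into an image or volume

  • torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU (the row is read as an unbatched sequence of length 1; these modules return a tuple, so select the output sequence with outputkey=[0])

  • torch.nn.TransformerEncoder (torch.nn.TransformerDecoder also requires a memory tensor, so it needs a custom forward() that supplies one)

  • Custom torch.nn.Module subclasses

  • Pre-trained models from torchvision, torchtext, and torchaudio, once fine-tuned and given a forward() that builds their expected input from the row (weights loaded from a checkpoint carry no gradients, so the training check fails until a backward pass has run)

  • Multi-output models (with the outputkey parameter)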
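Because the wrapper passes one (1, n_inputs) row, a model that expects an image or a batched sequence can be given a thin adapter module that reshapes the row inside forward(). Both classes below are hypothetical sketches: ImageAdapter reads 16 inputs as a single-channel 4x4 image for a Conv2d, and SequenceAdapter reads the inputs as a sequence of scalars for an LSTM and returns only a prediction from the last hidden state, so no outputkey is needed.

```python
import torch


class ImageAdapter(torch.nn.Module):
    """Hypothetical adapter: (1, 16) row -> (1, 1, 4, 4) image -> Conv2d -> (1, 1)."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3)
        self.head = torch.nn.Linear(4 * 2 * 2, 1)

    def forward(self, x):
        img = x.reshape(-1, 1, 4, 4)               # (N, C, H, W)
        h = torch.relu(self.conv(img))             # (N, 4, 2, 2) with a 3x3 kernel, no padding
        return self.head(h.flatten(start_dim=1))   # (N, 1)


class SequenceAdapter(torch.nn.Module):
    """Hypothetical adapter: (1, n) row -> length-n sequence of scalars -> LSTM -> (1, 1)."""

    def __init__(self, hidden=8):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=1, hidden_size=hidden, batch_first=True)
        self.head = torch.nn.Linear(hidden, 1)

    def forward(self, x):
        seq = x.unsqueeze(-1)           # (N, n) -> (N, n, 1)
        _, (h_n, _) = self.lstm(seq)    # LSTM returns (output, (h_n, c_n)); keep h_n
        return self.head(h_n[-1])       # last layer's hidden state -> (N, 1)


print(ImageAdapter()(torch.zeros(1, 16)).shape)   # torch.Size([1, 1])
print(SequenceAdapter()(torch.zeros(1, 4)).shape) # torch.Size([1, 1])
```

An adapter is trained like any other module and can then be passed to uranie_evaluator in place of the bare layer.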

Supported Output Types

  • Scalar outputs (int, float)

  • 1-D tensors or arrays

  • Tuples, lists, and dictionaries (use the outputkey parameter to select entries)

  • Multi-dimensional tensors (flattened, with a warning)
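A sketch of how the output types listed above can be reduced to the 1-D list of floats the callable returns. to_float_list is hypothetical; it flattens tensors in row-major (C) order, as torch.Tensor.flatten does, and concatenates container entries in iteration order (insertion order for dicts). Whether the real wrapper uses the same order for nested containers is an assumption.

```python
import torch


def to_float_list(out):
    """Hypothetical flattening of a model output into a flat list of floats."""
    if isinstance(out, (int, float)):
        return [float(out)]
    if isinstance(out, torch.Tensor):
        # flatten() is row-major and also turns a 0-d tensor into one element.
        return [float(v) for v in out.detach().flatten().tolist()]
    if isinstance(out, dict):
        out = list(out.values())
    if isinstance(out, (tuple, list)):
        flat = []
        for item in out:
            flat.extend(to_float_list(item))
        return flat
    raise TypeError(f"unsupported output type: {type(out).__name__}")


print(to_float_list(3))                                         # [3.0]
print(to_float_list(torch.tensor([[1.0, 2.0], [3.0, 4.0]])))    # [1.0, 2.0, 3.0, 4.0]
print(to_float_list({"mean": torch.tensor([0.5]), "std": 2}))   # [0.5, 2.0]
```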