Inference
Using a trained model to make predictions on new data
Inference is the process of using a trained model to make predictions on new, unseen data. Unlike training, inference only requires a forward pass — no gradients are computed.
Example
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(784, 10).to(device)  # placeholder model; use your trained model here
model.eval()  # switch to evaluation mode
with torch.no_grad():  # disable gradient tracking; saves memory
    input_data = torch.randn(1, 784, device=device)
    prediction = model(input_data)
    print(prediction.argmax(dim=1))  # predicted class index
Training vs. Inference
| | Training | Inference |
|---|---|---|
| Gradients | Yes | No |
| Memory | High (stores activations) | Lower |
| Batch size | Large | Often 1 or small |
| Goal | Learn weights | Make predictions |
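The gradient row of the table can be verified directly: a forward pass in training mode builds an autograd graph, while the same pass under `torch.no_grad()` does not. A minimal sketch, using a small `nn.Linear` stand-in for a trained model:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in model for illustration
x = torch.randn(1, 4)

# Training-style forward pass: autograd records operations
out_train = model(x)
print(out_train.requires_grad)  # True: output is attached to the graph

# Inference-style forward pass: no graph is recorded
model.eval()
with torch.no_grad():
    out_infer = model(x)
print(out_infer.requires_grad)  # False: nothing to backpropagate through
```

Because no graph is stored, the inference pass avoids keeping intermediate activations, which is the source of the memory savings listed above.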