Structure of an ML model

Ultimately, the exact software stack or framework you use isn’t that important. Once you begin to grasp the models and their estimation, switching between the frameworks and languages built around those models becomes easier. I used to think “I code in R, I can’t work with Python” or “I use frequentist SEM models, I can’t train neural networks”. While the lower-level choices (PyTorch or JAX, R or Python) are hidden away under high-level libraries in parts of this course, it is a good exercise to see the same model represented in different languages and frameworks.

The goal here is twofold: to see the basic structure of an ML training script, and to begin to cognitively separate the model from the language and the framework. Look through the model specifications below by flicking through the tabs containing the code; in each you will see the same multi-layer perceptron (MLP), written in five combinations of framework and language: PyTorch, torch for R, Keras (R), Keras (Python), and JAX. The goal is to help you see one model, represented in various ways.

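Our minimal MLP has:

  • an input of 10 features,
  • one hidden layer of 20 units with a ReLU activation,
  • a single scalar output.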

We’ll describe each matrix and bias term, then express the full model as a symbolic equation.

Matrices and Vectors in the Model

The model uses the following matrices and vectors:

  • \(\mathbf{W}_1 \in \mathbb{R}^{10 \times 20}\): First layer’s weight matrix, mapping 10-dimensional input to a 20-dimensional hidden layer.
  • \(\mathbf{b}_1 \in \mathbb{R}^{1 \times 20}\): Bias vector added after the first linear transformation (broadcast across the rows when the input is a batch).
  • \(\mathbf{W}_2 \in \mathbb{R}^{20 \times 1}\): Second layer’s weight matrix, mapping the hidden representation to a scalar output.
  • \(\mathbf{b}_2 \in \mathbb{R}^{1 \times 1}\): Final bias term (can be treated as a scalar).
  • \(\mathbf{x} \in \mathbb{R}^{1 \times 10}\): A single input example (can also be a batch of size \(B\): \(\mathbb{R}^{B \times 10}\)).
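As a quick sanity check on the shapes listed above, the following minimal sketch counts the trainable parameters. It uses only the Python standard library; the names `shapes` and `n_params` are my own, chosen for illustration.

```python
from math import prod

# Shapes of the trainable parameters, copied from the list above.
shapes = {
    "W1": (10, 20),
    "b1": (1, 20),
    "W2": (20, 1),
    "b2": (1, 1),
}

# Each parameter contributes rows * columns values: 200 + 20 + 20 + 1.
n_params = sum(prod(shape) for shape in shapes.values())
print(n_params)  # 241
```

The input \(\mathbf{x}\) is not counted: it is data, not something the model learns.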

Mathematical Equation of the Model

The MLP performs the following sequence of computations:

\[ \mathbf{z} = \mathbf{x} \mathbf{W}_1 + \mathbf{b}_1 \quad \text{(Linear transformation)} \]

\[ \mathbf{h} = \text{ReLU}(\mathbf{z}) \quad \text{(Nonlinearity applied elementwise)} \]

\[ \mathbf{o} = \mathbf{h} \mathbf{W}_2 + \mathbf{b}_2 \quad \text{(Final linear layer)} \]

Putting it all together, the full function computed by the MLP is:

\[ f(\mathbf{x}) = \left( \text{ReLU}(\mathbf{x} \mathbf{W}_1 + \mathbf{b}_1) \right) \mathbf{W}_2 + \mathbf{b}_2 \]

This is the same whether in PyTorch, Keras, torch for R, or JAX: only the syntax changes, not the structure.
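To make the equation concrete without committing to any of the five frameworks, here is a minimal sketch in plain NumPy. NumPy is my choice for illustration only. The parameter values are random, and the names `relu`, `f`, `x_single` and `x_batch` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Parameters with the shapes given above (random values, purely illustrative).
W1 = rng.normal(size=(10, 20))
b1 = np.zeros((1, 20))
W2 = rng.normal(size=(20, 1))
b2 = np.zeros((1, 1))

def relu(z):
    """Elementwise ReLU: max(0, z)."""
    return np.maximum(z, 0.0)

def f(x):
    """f(x) = ReLU(x W1 + b1) W2 + b2."""
    return relu(x @ W1 + b1) @ W2 + b2

x_single = rng.normal(size=(1, 10))   # one example
x_batch = rng.normal(size=(32, 10))   # a batch of B = 32 examples

print(f(x_single).shape)  # (1, 1)
print(f(x_batch).shape)   # (32, 1)
```

The same function handles a single example and a batch because the bias terms are broadcast over the rows of the input.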

The machinery around the model

These scripts all contain the following structural elements, related to the model and to the optimisation of the model given data:

Model definition: This is where the architecture of the neural network is specified, including a forward function, which defines how inputs are transformed through layers to produce outputs.

Loss function and optimizer: After the model makes a prediction, we need to compare it to the ground truth in the training data. The loss function defines how this comparison is made and summarizes it in a single value. The user also specifies an optimizer: an algorithm that determines, based on the gradients of the loss with respect to the parameters in this (and sometimes previous) iterations, how to change the parameters in order to minimize the loss.
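To see what a loss function and an optimizer do in isolation, here is a minimal NumPy sketch on a toy model with a single parameter, pred = w * x. This toy model is an assumption made for illustration, not the MLP itself. The helper names `mse_loss` and `sgd_step` are hypothetical, and the gradient is derived by hand rather than by a framework.

```python
import numpy as np

def mse_loss(pred, y):
    """Mean squared error: one number summarising the prediction error."""
    return float(np.mean((pred - y) ** 2))

def sgd_step(param, grad, lr):
    """Plain SGD: move the parameter a small step against its gradient."""
    return param - lr * grad

# Toy data generated with w = 2, and a deliberately wrong starting value.
x = np.array([1.0, 2.0, 3.0])
y = np.array([2.0, 4.0, 6.0])
w = 0.0

loss_before = mse_loss(w * x, y)

# d(loss)/dw for pred = w * x, derived by hand: mean(2 * (w * x - y) * x).
grad_w = float(np.mean(2.0 * (w * x - y) * x))

w = sgd_step(w, grad_w, lr=0.1)
loss_after = mse_loss(w * x, y)

print(loss_before, loss_after)  # the loss is lower after the step
```

Optimizers such as SGD with momentum or Adam differ only in how `sgd_step` turns the current (and past) gradients into an update.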

Having defined the model, loss, and optimizer, the user specifies a training loop:

Forward pass: In this step, input data is passed through the model to compute predictions based on the current weights.

Loss computation: The model’s predictions are compared to the true target values using a loss function that quantifies the prediction error.

Backpropagation: The gradients of the loss with respect to each model parameter are calculated by propagating the error backward through the network.

Weight updates: The model’s parameters are adjusted using an optimization algorithm (like stochastic gradient descent) to reduce the loss in future predictions.
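The four steps above can be written out by hand, which is roughly what the frameworks automate. The following is a minimal NumPy sketch under stated assumptions. The synthetic data, the initialisation scale, the batch size and the helper names `forward`, `mse` and `backward` are all my own choices for illustration. The gradients are derived with the chain rule instead of automatic differentiation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic regression data (an assumption): 256 examples with 10 features.
X = rng.normal(size=(256, 10))
Y = X @ rng.normal(size=(10, 1)) + 0.1 * rng.normal(size=(256, 1))

# Parameters with the shapes used in this section, small random initial values.
params = {
    "W1": 0.1 * rng.normal(size=(10, 20)),
    "b1": np.zeros((1, 20)),
    "W2": 0.1 * rng.normal(size=(20, 1)),
    "b2": np.zeros((1, 1)),
}

def forward(params, x):
    """Forward pass; also returns the intermediates needed for backpropagation."""
    z = x @ params["W1"] + params["b1"]
    h = np.maximum(z, 0.0)
    o = h @ params["W2"] + params["b2"]
    return o, (x, z, h)

def mse(pred, y):
    """Mean squared error over all examples in the batch."""
    return float(np.mean((pred - y) ** 2))

def backward(params, cache, pred, y):
    """Gradients of the MSE loss with respect to every parameter."""
    x, z, h = cache
    d_o = 2.0 * (pred - y) / y.size    # dLoss/do
    d_h = d_o @ params["W2"].T         # back through the final linear layer
    d_z = d_h * (z > 0)                # back through the ReLU
    return {
        "W1": x.T @ d_z,
        "b1": d_z.sum(axis=0, keepdims=True),
        "W2": h.T @ d_o,
        "b2": d_o.sum(axis=0, keepdims=True),
    }

lr, batch_size = 0.01, 32
loss_start = mse(forward(params, X)[0], Y)

for epoch in range(50):
    order = rng.permutation(len(X))    # reshuffle the examples every epoch
    for start in range(0, len(X), batch_size):
        idx = order[start:start + batch_size]
        x, y = X[idx], Y[idx]
        pred, cache = forward(params, x)            # 1. forward pass
        loss = mse(pred, y)                         # 2. loss computation
        grads = backward(params, cache, pred, y)    # 3. backpropagation
        for name in params:                         # 4. weight update (plain SGD)
            params[name] = params[name] - lr * grads[name]

loss_end = mse(forward(params, X)[0], Y)
print(f"loss before: {loss_start:.3f}, after: {loss_end:.3f}")
```

In the framework versions, `backward` is the part you never write yourself: `loss.backward()` in PyTorch and `grad(loss_fn)` in JAX compute the same quantities automatically.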

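PyTorch

import torch
import torch.nn as nn
import torch.optim as optim

# Model: 10 inputs -> 20 hidden units (ReLU) -> 1 output
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1))

    def forward(self, x):
        return self.net(x)

model = MLP()
opt = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

# Training loop (`data` is assumed to be an iterable of (x, y) batches)
for x, y in data:
    pred = model(x)          # forward pass
    loss = loss_fn(pred, y)  # loss computation
    opt.zero_grad()          # clear gradients left over from the previous step
    loss.backward()          # backpropagation
    opt.step()               # weight update

torch for R

library(torch)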

# Model
mlp <- nn_module(
  initialize = function() {
    self$fc1 <- nn_linear(10, 20)
    self$fc2 <- nn_linear(20, 1)
  },
  forward = function(x) {
    x %>% self$fc1() %>% nnf_relu() %>% self$fc2()
  }
)

model <- mlp()
opt <- optim_sgd(model$parameters, lr = 0.01)
loss_fn <- nn_mse_loss()

# Training loop
for (batch in dataloader) {
  x <- batch$x
  y <- batch$y
  pred <- model(x)
  loss <- loss_fn(pred, y)
  opt$zero_grad()
  loss$backward()
  opt$step()
}

Keras (R)

library(keras)

# Model
model <- keras_model_sequential() %>%
  layer_dense(units = 20, activation = "relu", input_shape = 10) %>%
  layer_dense(units = 1)

model %>% compile(optimizer = "sgd", loss = "mse")

# Training
model %>% fit(x_train, y_train, epochs = 10)

Keras (Python)

from tensorflow import keras
from tensorflow.keras import layers

# Model
model = keras.Sequential([
    layers.Dense(20, activation='relu', input_shape=(10,)),
    layers.Dense(1)
])
model.compile(optimizer='sgd', loss='mse')

# Training
model.fit(x_train, y_train, epochs=10)

JAX

import jax
import jax.numpy as jnp
from jax import grad, jit, random

# Parameters
key = random.PRNGKey(0)
def init_params(key):
    k1, k2 = random.split(key)
    return {
        "W1": random.normal(k1, (10, 20)),
        "b1": jnp.zeros(20),
        "W2": random.normal(k2, (20, 1)),
        "b2": jnp.zeros(1)
    }

def forward(params, x):
    x = jnp.dot(x, params["W1"]) + params["b1"]
    x = jax.nn.relu(x)
    return jnp.dot(x, params["W2"]) + params["b2"]

def loss_fn(params, x, y):
    pred = forward(params, x)
    return jnp.mean((pred - y) ** 2)

@jit
def train_step(params, x, y, lr=0.01):
    grads = grad(loss_fn)(params, x, y)
    return {k: params[k] - lr * grads[k] for k in params}

params = init_params(key)

# Training loop (`data` is assumed to be an iterable of (x, y) batches)
for x, y in data:
    params = train_step(params, x, y)

The models are clearly the same; the frameworks and languages differ mainly in verbosity. In Keras for R the whole script is only a few lines, because many common operations, such as the mean squared error (mse) loss, are baked in, while in JAX things are defined more manually. I feel very comfortable with PyTorch myself, and it has a wide user base in the language/sequence/image deep-learning communities, which means there is a lot of great code available. Switching from one to the other will make you stumble, but it’s not as insurmountable as many people believe.

As you build bigger and bigger models, it is mainly the model section of the code that grows. Usually you’d also want to build in structures to read and validate training data and process it per batch, as well as code to save and share trained models. Those creature comforts are provided by libraries like transformers, which do then tend to abstract away the exact representation of the model.
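To give a feel for what that surrounding machinery looks like when written by hand, here is a minimal NumPy sketch of two such helpers: splitting data into batches and saving parameters to disk. The function names `iter_batches`, `save_params` and `load_params` are hypothetical, and the `.npz` file format is simply a convenient assumption. Libraries like transformers use their own formats and loaders.

```python
import os
import tempfile

import numpy as np

def iter_batches(x, y, batch_size):
    """Yield consecutive (x, y) mini-batches; the last one may be smaller."""
    if len(x) != len(y):
        raise ValueError("x and y must contain the same number of examples")
    for start in range(0, len(x), batch_size):
        yield x[start:start + batch_size], y[start:start + batch_size]

def save_params(path, params):
    """Store a dict of named arrays in a single .npz file."""
    np.savez(path, **params)

def load_params(path):
    """Read the arrays back into a plain dict."""
    with np.load(path) as archive:
        return {name: archive[name] for name in archive.files}

# 70 examples in batches of 32 gives two full batches and one partial batch.
x = np.zeros((70, 10))
y = np.zeros((70, 1))
batch_sizes = [len(xb) for xb, _ in iter_batches(x, y, 32)]
print(batch_sizes)  # [32, 32, 6]

# Round trip: save the parameters to a temporary file and load them again.
params = {"W1": np.ones((10, 20)), "b1": np.zeros((1, 20))}
path = os.path.join(tempfile.mkdtemp(), "mlp_params.npz")
save_params(path, params)
restored = load_params(path)
print(sorted(restored))  # ['W1', 'b1']
```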