Piecewise Linear Curves in PyTorch

In this blog, I will train a simple piecewise linear (PWL) curve on dummy data using PyTorch. But first, why a piecewise linear curve? PWL curves are a set of linear segments joined at common points. They can approximate any non-linear curve, and their simplicity makes the predictions easy to explain. Moreover, they can be easily deployed and scaled, and you can add constraints such as monotonicity to them.


So, let's get started.

Training a PWL curve is similar to training a regression model, except that the output is not a function of a single line but of multiple line segments, each responsible for its own interval of the input.

Bounded vs Unbounded: Linear curves are generally unbounded in nature, as they can be applied directly to values outside their training range, and the same is true of PWL curves. For out-of-bound inputs, one can either return the curve's values at the min and max of the range, or extrapolate using the first and last segments. For the purpose of this example, we will use the min and max curve values for out-of-bound inputs. Estimating the range of input values is therefore important for finding appropriate points along which to create the curve.
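As a small illustrative sketch (the helper name is my own), returning the min and max curve values for out-of-bound inputs is equivalent to clamping the input to the training range before evaluating the curve:

import torch

def clip_input(x, rng=(0, 5)):
  # Out-of-bound inputs are pulled back to the training range, so they
  # receive the curve values at the range boundaries
  start, end = rng
  return torch.clamp(x, min=start, max=end)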

Input Range and Curve Points: The first step in training a PWL curve is to determine the points between which the linear segments will be created. Multiple heuristics can be used here. One strategy is to place the points at different percentiles of the inputs across the range, which ensures that each segment gets an equal number of points for training. Another is to simply take k equidistant points between the min and max. For our example, we implement the latter strategy.

self.xpoints = torch.linspace(start, end, ncurves + 1)
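For reference, here is a sketch of the percentile strategy as well (x_train is an assumed 1-D tensor of training inputs; this variant is not used in the rest of the post):

# Place the x points at equally spaced quantiles of the training inputs so
# that each segment covers roughly the same number of samples
qs = torch.linspace(0, 1, ncurves + 1)
xpoints = torch.quantile(x_train, qs)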

PWL Representation and Training Parameters: A PWL curve can be represented by a set of x points, as shown above, together with a set of y points giving the value of the curve at those x points. During inference, we identify which segment to use based on the input and the x points, and generate the prediction with that segment. Since our x points are fixed, we only train the set of y values at those points.

$$ \text{PWL}_n = \{(x_1, y_1), (x_2, y_2), \dots, (x_n, y_n)\} $$

$$ \text{PWL}_n(x) = \text{Linear}(x, x_i, x_{i+1}, y_i, y_{i+1}) \quad \text{for } x_i < x \le x_{i+1} $$

$$ \text{Linear}(x, x_i, x_{i+1}, y_i, y_{i+1}) = \frac{y_{i+1}-y_i}{x_{i+1}-x_i} (x-x_i) + y_i $$

import torch

class PWL(torch.nn.Module):
  def __init__(self, ncurves=3, rng=(0, 5)):
    super().__init__()
    start, end = rng
    # Fixed x points bounding each linear segment
    self.xpoints = torch.linspace(start, end, ncurves + 1)
    # Trainable y values at those x points
    self.weights = torch.nn.Parameter(torch.randn(ncurves + 1,))

  def calculateLinear(self, x, start, end, ystart, yend):
    # Equation of the straight line through (start, ystart) and (end, yend):
    # y = (yend - ystart) / (end - start) * (x - start) + ystart
    return (yend - ystart) / (end - start) * (x - start) + ystart

  def forward(self, x):
    # This linear scan could be optimized with a binary search
    # for a large number of curves
    if x <= self.xpoints[0]:
      return self.weights[0]
    for i, (start, end) in enumerate(zip(self.xpoints[:-1], self.xpoints[1:])):
      if x > start and x <= end:
        return self.calculateLinear(x, start, end, self.weights[i], self.weights[i + 1])
    # x is above the last point: return the last y value
    return self.weights[-1]
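A quick sanity check of the module (illustrative values, untrained weights):

model = PWL(ncurves=3, rng=(0, 5))
print(model(torch.tensor(2.5)))   # interpolated between the two nearest y points
print(model(torch.tensor(-1.0)))  # below the range: returns the first y value
print(model(torch.tensor(6.0)))   # above the range: returns the last y value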

Training Strategy: We train the PWL curve the same way we fit a regression model: calculate the error between the target and the prediction, then backpropagate the loss to the trainable y parameters.

import tqdm

def train(model, x, y):
  optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  for epoch in (pbar := tqdm.tqdm(range(500))):

    # Forward pass (one sample at a time, since forward is not batched)
    y_pred = [model(i) for i in x]
    error = [(pred - act) ** 2 for pred, act in zip(y_pred, y)]

    # Loss calculation: mean squared error over all samples
    loss = sum(error) / len(error)
    loss_scalar = loss.item()
    pbar.set_description(f"Epoch: {epoch}, Loss: {loss_scalar:.2f}")

    # Backward pass
    loss.backward()

    # Update the parameters
    optimizer.step()

    # Clear the gradients
    optimizer.zero_grad()

  return loss_scalar

Sample Training and Curve:
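For instance, a run on illustrative dummy data (a noisy sine, which is non-linear over the training range) could look like this:

torch.manual_seed(0)
x = torch.rand(200) * 5                     # inputs in the range (0, 5)
y = torch.sin(x) + 0.1 * torch.randn(200)   # noisy non-linear target

model = PWL(ncurves=5, rng=(0, 5))
final_loss = train(model, x, y)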

Finding the Optimal Number of Curves Using an Elbow Curve: An elbow curve is generated by recording the optimized loss while increasing a hyperparameter (here, the number of curves). As the number of curves increases, the loss goes down, because the model can learn more complex characteristics of the data. The loss drops sharply while the model is still underfitting; after that, only incremental improvements are observed. The number of curves just after this sharp drop can be taken as the final choice.
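A minimal sweep to produce the elbow data, reusing the dummy x, y from the training example above (illustrative):

losses = []
candidates = list(range(1, 11))
for ncurves in candidates:
  model = PWL(ncurves=ncurves, rng=(0, 5))
  losses.append(train(model, x, y))
# Plot candidates against losses and pick the point where the drop flattens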


The elbow method is generally used in clustering to find the number of clusters, but it can be applied to any learning algorithm with a hyperparameter that controls model complexity. Alternatives to this method can be found in the related reads.

Advantage over Polynomial Regression?

  • Less overfitting to the distribution of data points: polynomial regression tends to overfit in regions with more samples (because those regions contribute more to the loss)
  • The curve is simple to explain and understand

Closing and Future Direction

PWL curves are commonly used to train GAMs (generalized additive models), since an individual curve can be trained for each feature; a rough sketch of this follows below. You can find the complete code along with a working example in this colab. An important extension would be to learn the split points from the data instead of using the heuristics above. Thanks for reading and keep learning!
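As a rough sketch of that idea (my own illustration, not from the colab), a GAM can be formed by summing one PWL curve per feature, assuming a shared input range for simplicity:

class GAM(torch.nn.Module):
  def __init__(self, nfeatures, ncurves=3, rng=(0, 5)):
    super().__init__()
    # One independently trainable PWL curve per feature
    self.curves = torch.nn.ModuleList(
        [PWL(ncurves=ncurves, rng=rng) for _ in range(nfeatures)])

  def forward(self, x):
    # x: 1-D tensor with one value per feature; the prediction is the sum
    # of each feature's curve output
    return sum(curve(xi) for curve, xi in zip(self.curves, x))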

PS: The current PyTorch implementation is not efficient, as it doesn't accept batched input. Leaving that as an exercise.

Related Reads:

