🧪 Time Series Preprocessing with JAX – The Foundation for Accurate Forecasts
Welcome back to the JAX for Time Series series.
Before we dive into deep models like LSTMs or Transformers, we need to do something most tutorials skip over:
proper time series preprocessing.
Why? Because garbage in = garbage out.
Even the most powerful neural networks can’t save you from bad input structure.
⚙️ Step 1: Load and Normalize Your Time Series Data
Let’s start with something simple: a univariate time series, namely a sine wave plus Gaussian noise.
```python
import numpy as np
import jax.numpy as jnp
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

np.random.seed(42)  # fix the noise so results are reproducible

# Simulate time series data: 10 periods of a sine wave plus Gaussian noise
t = np.linspace(0, 20 * np.pi, 1000)
y = np.sin(t) + 0.1 * np.random.randn(1000)

plt.plot(y)
plt.title("Raw Sine Wave with Noise")
plt.show()
```
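Now we normalize it. To avoid leaking information from the test period into preprocessing, fit the scaler on the first 80% of the series only (the same fraction we'll use for training in Step 3), then apply it to the whole series:
```python
train_end = int(0.8 * len(y))  # fit only on the training period to avoid leakage
scaler = StandardScaler().fit(y[:train_end].reshape(-1, 1))
y_scaled = scaler.transform(y.reshape(-1, 1)).flatten()
```
🔁 Standardization (zero mean, unit variance) is a common default for neural nets: unlike min-max scaling, it doesn't squeeze the whole series into a fixed range set by its two most extreme values.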
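If you'd rather keep the whole pipeline in JAX, standardization is just two statistics, and you'll also need the inverse to report forecasts in the original units. A minimal sketch, assuming the hypothetical helper names `fit_standardizer`, `standardize`, and `unstandardize` (not part of JAX or scikit-learn). `jnp.std` defaults to the population standard deviation (`ddof=0`), the same convention `StandardScaler` uses.

```python
import numpy as np
import jax.numpy as jnp

def fit_standardizer(train_series):
    # Hypothetical helper: statistics come from the training portion only.
    mean = jnp.mean(train_series)
    std = jnp.std(train_series)  # ddof=0, same convention as StandardScaler
    return mean, std

def standardize(series, mean, std):
    return (series - mean) / std

def unstandardize(scaled, mean, std):
    # Map scaled values (e.g. model forecasts) back to the original units.
    return scaled * std + mean

# Same kind of synthetic signal as above, seeded for reproducibility.
rng = np.random.default_rng(0)
t = np.linspace(0, 20 * np.pi, 1000)
y = jnp.asarray(np.sin(t) + 0.1 * rng.standard_normal(1000))

train_end = int(0.8 * len(y))
mean, std = fit_standardizer(y[:train_end])
y_scaled = standardize(y, mean, std)
y_back = unstandardize(y_scaled, mean, std)
```

Keeping `mean` and `std` around is the important part: every forecast the model produces later lives in scaled units and has to pass through `unstandardize` before you compare it with real data.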
🪟 Step 2: Create Sliding Windows
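A neural network won’t treat the data as a sequence unless we format it that way: it needs fixed-size (input, target) pairs.
So let’s transform the 1D time series into overlapping input-output windows, where each input holds `input_size` consecutive values and each target holds the next `output_size` values.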
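To make the layout concrete, here is the same indexing on a toy 7-step series with `input_size=3` and `output_size=2` (toy values chosen for illustration only):

```python
import numpy as np

series = np.arange(7)  # toy series: 0, 1, ..., 6
input_size, output_size = 3, 2

# Every valid start position: 7 - 3 - 2 + 1 = 3 windows
n_windows = len(series) - input_size - output_size + 1
for i in range(n_windows):
    x_win = series[i : i + input_size]                              # past values
    y_win = series[i + input_size : i + input_size + output_size]   # values to predict
    print(x_win.tolist(), "->", y_win.tolist())
# [0, 1, 2] -> [3, 4]
# [1, 2, 3] -> [4, 5]
# [2, 3, 4] -> [5, 6]
```

Each window starts one step after the previous one, so a series of length N yields N - input_size - output_size + 1 windows.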
```python
def make_windows(series, input_size=50, output_size=10):
    """Split a 1D series into (input, target) pairs with a stride of 1."""
    X, Y = [], []
    # +1 so the last window (ending exactly at the final time step) is included
    for i in range(len(series) - input_size - output_size + 1):
        X.append(series[i : i + input_size])
        Y.append(series[i + input_size : i + input_size + output_size])
    return jnp.asarray(np.stack(X)), jnp.asarray(np.stack(Y))

X, Y = make_windows(y_scaled)
print("Input shape:", X.shape)   # (941, 50)
print("Target shape:", Y.shape)  # (941, 10)
```
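The Python loop is easy to read but builds windows one at a time. A vectorized alternative builds an index matrix and gathers every window in a single indexing operation. This is a sketch; `make_windows_vectorized` is a hypothetical name, not a JAX API. It produces `len(series) - input_size - output_size + 1` windows, one per valid start position.

```python
import jax.numpy as jnp

def make_windows_vectorized(series, input_size=50, output_size=10):
    # Hypothetical alternative to the loop: gather all windows at once.
    series = jnp.asarray(series)
    window = input_size + output_size
    n_windows = series.shape[0] - window + 1  # every valid starting index
    # idx[i, j] = i + j, so row i holds the time indices of window i
    idx = jnp.arange(n_windows)[:, None] + jnp.arange(window)[None, :]
    windows = series[idx]  # shape: (n_windows, input_size + output_size)
    return windows[:, :input_size], windows[:, input_size:]

# Toy series whose values equal their time index, so windows are easy to check
series = jnp.arange(1000, dtype=jnp.float32)
X, Y = make_windows_vectorized(series)
print(X.shape, Y.shape)  # (941, 50) (941, 10)
```

The output shapes depend only on `series.shape`, `input_size`, and `output_size`, so if you wrap this in `jax.jit`, mark `input_size` and `output_size` as static arguments.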
🧪 Step 3: Train/Test Split
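Split chronologically, with no shuffling, so every test window starts later in time than every training window:
```python
split = int(0.8 * len(X))  # first 80% of windows for training, the rest for testing
X_train, X_test = X[:split], X[split:]
Y_train, Y_test = Y[:split], Y[split:]
print("Train samples:", X_train.shape[0])
print("Test samples:", X_test.shape[0])
```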
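One subtlety: consecutive windows overlap, so the last training windows share raw time steps with the first test windows (their targets reach into the test inputs). For a stricter evaluation you can drop a gap of `input_size + output_size - 1` windows at the boundary, the same idea as the `gap` parameter of scikit-learn's `TimeSeriesSplit`. A minimal sketch with a hypothetical `chronological_split` helper, checked on a toy series whose values equal their time index:

```python
import jax.numpy as jnp

def chronological_split(X, Y, train_frac=0.8, gap=0):
    # Hypothetical helper: chronological split that discards `gap` windows
    # between train and test. With stride-1 windows of length L,
    # gap = L - 1 guarantees no time step appears in both sets.
    split = int(train_frac * X.shape[0])
    return (X[:split], Y[:split]), (X[split + gap:], Y[split + gap:])

# Toy windows over a series whose values equal their time index
series = jnp.arange(1000, dtype=jnp.float32)
input_size, output_size = 50, 10
window = input_size + output_size
idx = jnp.arange(series.shape[0] - window + 1)[:, None] + jnp.arange(window)[None, :]
W = series[idx]
X, Y = W[:, :input_size], W[:, input_size:]

(X_train, Y_train), (X_test, Y_test) = chronological_split(X, Y, gap=window - 1)
print(X_train.shape[0], X_test.shape[0])  # 752 130
```

The trade-off is a smaller test set (130 instead of 189 windows here) in exchange for a test set that shares no raw values with training.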
📦 Bonus: Batch Generator (for JAX)
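For mini-batch training, let’s build a simple generator that reshuffles the training windows on every pass (epoch):
```python
def get_batches(X, Y, batch_size):
    """Yield shuffled (X, Y) mini-batches covering one epoch."""
    n = X.shape[0]
    indices = np.random.permutation(n)  # new random order on every call
    for i in range(0, n, batch_size):
        idx = indices[i : i + batch_size]  # the final batch may be smaller
        yield X[idx], Y[idx]

for xb, yb in get_batches(X_train, Y_train, batch_size=32):
    pass  # a training step would go here
```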
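Two JAX-specific refinements you may want later: shuffling with an explicit PRNG key (reproducible, and consistent with JAX's functional randomness), and dropping the final partial batch, because under `jax.jit` every new input shape triggers a recompilation. A sketch under those assumptions; `get_batches_jax` and its `drop_last` flag are hypothetical names:

```python
import jax
import jax.numpy as jnp

def get_batches_jax(key, X, Y, batch_size, drop_last=True):
    # Hypothetical variant: shuffle with a JAX PRNG key and, optionally,
    # skip the last partial batch so every batch has the same shape.
    n = X.shape[0]
    perm = jax.random.permutation(key, n)
    stop = n - n % batch_size if drop_last else n
    for i in range(0, stop, batch_size):
        idx = perm[i : i + batch_size]
        yield X[idx], Y[idx]

# Toy data: every value in row r equals r, so rows are easy to track
X = jnp.tile(jnp.arange(752.0)[:, None], (1, 50))
Y = jnp.tile(jnp.arange(752.0)[:, None], (1, 10))

key = jax.random.PRNGKey(0)
batches = list(get_batches_jax(key, X, Y, batch_size=64))
print(len(batches), batches[0][0].shape)  # 11 (64, 50)
```

With 752 windows and a batch size of 64, the last 48 shuffled windows are skipped in this epoch; split the key (`jax.random.split`) each epoch so a different subset gets dropped each time.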
🧠 What You Have Now
You’ve just created a clean, normalized, batch-ready dataset for time series forecasting, stored as JAX arrays.
| Task | Status |
|---|---|
| Simulated noisy signal | ✅ |
| Preprocessing pipeline (scaling + windowing) | ✅ |
| JAX-formatted data | ✅ |
| Batched for training | ✅ |
What’s Next?
In the next post, we’ll train our first deep learning model:
a vanilla LSTM using JAX and Flax.
This will be our benchmark model for later experiments like Transformer, VAE, and GPT-based architectures.
Let’s go from preprocessing to actual learning. 🚀