
🧪 Time Series Preprocessing with JAX – The Foundation for Accurate Forecasts


Welcome back to the JAX for Time Series series.

Before we dive into deep models like LSTMs and Transformers, we need to do something that most tutorials skip over:
proper time series preprocessing.

Why? Because garbage in = garbage out.
Even the most powerful neural networks can’t save you from bad input structure.


⚙️ Step 1: Load and Normalize Your Time Series Data

Let’s start with something simple: a univariate time series, a sine wave plus some noise.

import numpy as np
import pandas as pd
import jax.numpy as jnp
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

# Simulate time series data
t = np.linspace(0, 20 * np.pi, 1000)
y = np.sin(t) + 0.1 * np.random.randn(1000)

plt.plot(y)
plt.title("Raw Sine Wave with Noise")
plt.show()

Now we normalize it:

scaler = StandardScaler()
y_scaled = scaler.fit_transform(y.reshape(-1, 1)).flatten()

🔁 Standardization (zero mean, unit variance) is often a better fit than min-max scaling for neural nets: zero-centered inputs play well with common weight initializations.
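A nice side effect of using a fitted scaler object is that you can map model outputs back to the original scale later. A minimal sketch, assuming y_pred_scaled is a hypothetical 1D array of predictions in the standardized space:

# Invert the scaling on (hypothetical) predictions later in the pipeline
# y_pred_scaled: 1D array of model outputs in standardized units
y_pred = scaler.inverse_transform(np.asarray(y_pred_scaled).reshape(-1, 1)).flatten()

Keep the fitted scaler around; you’ll need it again at inference time.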


🪟 Step 2: Create Sliding Windows


Neural networks don’t “see” sequences unless we format them that way.
So let’s transform the 1D time series into overlapping input-output windows.

def make_windows(series, input_size=50, output_size=10):
    """Slice a 1D series into overlapping (input, target) windows."""
    X, Y = [], []
    for i in range(len(series) - input_size - output_size):
        X.append(series[i : i + input_size])  # past input_size steps
        Y.append(series[i + input_size : i + input_size + output_size])  # next output_size steps
    return jnp.array(X), jnp.array(Y)

X, Y = make_windows(y_scaled)
print("Input shape:", X.shape)   # e.g. (940, 50)
print("Target shape:", Y.shape)  # e.g. (940, 10)

🧪 Step 3: Train/Test Split

split = int(0.8 * len(X))  # chronological split: no shuffling, so the test set lies strictly in the future
X_train, X_test = X[:split], X[split:]
Y_train, Y_test = Y[:split], Y[split:]

print("Train samples:", X_train.shape[0])
print("Test samples:", X_test.shape[0])

📦 Bonus: Batch Generator (for JAX)

To speed up training, let’s build a simple mini-batch function for JAX:

def get_batches(X, Y, batch_size):
    """Yield shuffled mini-batches; the last batch may be smaller than batch_size."""
    n = X.shape[0]
    indices = np.random.permutation(n)  # host-side shuffle with NumPy
    for i in range(0, n, batch_size):
        idx = indices[i : i + batch_size]
        yield X[idx], Y[idx]
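Usage is straightforward (a batch size of 32 is an arbitrary choice for illustration):

for x_batch, y_batch in get_batches(X_train, Y_train, batch_size=32):
    print(x_batch.shape, y_batch.shape)  # e.g. (32, 50) (32, 10)
    break  # peek at the first batch only

Note that the shuffling happens on the host with NumPy; a jitted training step would only ever see the already-sliced batches.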

🧠 What You Have Now

You’ve just created a clean, normalized, batch-ready dataset for time series forecasting — with JAX compatibility.

Task | Status
Simulated real-world signal | ✅
Preprocessing pipeline | ✅
JAX-formatted data | ✅
Batched for training | ✅

What’s Next?

In the next post, we’ll train our first deep learning model:
a vanilla LSTM using JAX and Flax.

This will be our benchmark model for later experiments like Transformer, VAE, and GPT-based architectures.

Let’s go from preprocessing to actual learning. 🚀


 
