#  pytorch_mlp.py
#  G. Cowan / RHUL Physics / June 2025
#  Simple program to illustrate classification with pytorch
#  Based on simpleClassifier_mlp.py which used scikit-learn

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os, random
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from sklearn.model_selection import train_test_split
from sklearn import metrics

matplotlib.rcParams.update({'font.size':14})     # set all font sizes

# Reproducibility helpers
def seed_everything(seed: int = 42) -> None:
    """Seed every RNG used by this script and enable deterministic torch kernels.

    Seeds Python's ``random`` module, NumPy's global RNG and torch's CPU and
    CUDA generators with the same value, then asks torch/cuDNN to use
    deterministic algorithms only.

    Args:
        seed: Value used for every generator (default 42).

    Note:
        ``PYTHONHASHSEED`` is only read when an interpreter starts, so setting
        it here affects subprocesses started afterwards, not the string-hash
        randomization of the current process.
    """
    random.seed(seed)                    # Python RNG
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)                 # NumPy global RNG
    torch.manual_seed(seed)              # CPU torch RNG
    torch.cuda.manual_seed_all(seed)     # every visible GPU (lazy no-op without CUDA)

    # With CUDA >= 10.2, deterministic mode makes cuBLAS-backed ops raise a
    # RuntimeError unless a fixed cuBLAS workspace is configured.  The value
    # must be in place before cuBLAS is first used; keep any value the user
    # already exported (":16:8" is the other accepted setting).
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    # Force deterministic algorithms where possible
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False   # autotuner may pick different kernels per run
seed_everything(12345)

# Load the event samples (one event per row) and attach class labels:
# signal -> 1, background -> 0.
sigData = np.loadtxt('signal.txt')
nSig = sigData.shape[0]
bkgData = np.loadtxt('background.txt')
nBkg = bkgData.shape[0]
sigTargets = np.full(nSig, 1, dtype=np.int64)
bkgTargets = np.full(nBkg, 0, dtype=np.int64)

# Stack signal rows above background rows; for now keep only x1 and x2.
X_np = np.concatenate([sigData, bkgData], axis=0)[:, :2]
y_np = np.concatenate([sigTargets, bkgTargets])
X = torch.from_numpy(X_np).float()     # float32 features
y = torch.from_numpy(y_np).long()      # int64 class labels

# 50/50 train/test split, stratified so both halves keep the same
# signal/background proportions.
X_train_np, X_test_np, y_train_np, y_test_np = train_test_split(
    X_np, y_np, test_size=0.5, random_state=42, stratify=y_np)

# Torch copies of the split: float32 features, and labels reshaped into
# float32 column vectors (n, 1) to line up with the model's scalar output.
X_train = torch.from_numpy(X_train_np).float()
X_test = torch.from_numpy(X_test_np).float()
y_train = torch.from_numpy(y_train_np).float().unsqueeze(1)
y_test = torch.from_numpy(y_test_np).float().unsqueeze(1)

def make_mlp(n_inputs, hidden_sizes, n_outputs=1):
    """Build a fully connected ReLU network as an ``nn.Sequential``.

    Args:
        n_inputs: Number of input features.
        hidden_sizes: Width of each hidden layer, in order. An empty
            sequence gives a purely linear model.
        n_outputs: Number of outputs (default 1: a scalar decision function).

    Returns:
        ``nn.Sequential`` of alternating ``nn.Linear``/``nn.ReLU`` layers,
        ending in an ``nn.Linear`` with no activation (raw logits).
    """
    layers = []
    width_in = n_inputs
    for width in hidden_sizes:
        layers += [nn.Linear(width_in, width), nn.ReLU()]
        width_in = width
    layers.append(nn.Linear(width_in, n_outputs))   # scalar output, no activation
    return nn.Sequential(*layers)


# Define model.  Small networks are enough for this problem (original note:
# even 1 hidden layer with 10 nodes works well); the default below has
# 2 hidden layers of 10 nodes each.  A much larger network, e.g. 6 hidden
# layers with 100 nodes each (the commented line), shows significant
# overtraining.
model = make_mlp(2, [10, 10])
# model = make_mlp(2, [100] * 6)


# Loss and optimizer.  BCEWithLogitsLoss applies the sigmoid internally, so
# the model outputs raw logits t rather than probabilities.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop: full-batch gradient descent on the whole training set.
for epoch in range(200):
    optimizer.zero_grad()
    output = model(X_train)              # logits, one per training event
    loss = criterion(output, y_train)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        # Monitoring only, so no autograd graph is needed.  Note that these
        # accuracies use the weights *after* this epoch's update, whereas
        # `loss` was computed before it.
        with torch.no_grad():
            output_train = model(X_train)
            pred_train = (output_train > 0).float()   # logit > 0  <=>  p > 0.5
            acc_train = (pred_train == y_train).float().mean()
            output_test = model(X_test)
            pred_test = (output_test > 0).float()
            acc_test = (pred_test == y_test).float().mean()
        print(f"Epoch {epoch:3d} | Loss: {loss.item():.4f} | ",
              f"Accuracy (train): {acc_train.item():.4f} | ",
              f"Accuracy (test) = {acc_test.item():.4f}")
        
# Make a scatter plot for the training set, with the network's decision
# boundary t(x1, x2) = 0 drawn on top.
bkgTrain = X_train_np[y_train_np == 0]
sigTrain = X_train_np[y_train_np == 1]
fig, ax = plt.subplots(1,1)
fig.subplots_adjust(bottom=0.15, left=0.15)
ax.set_xlim((-2.5,3.5))
ax.set_ylim((-2,4))
x0,x1 = ax.get_xlim()
y0,y1 = ax.get_ylim()
ax.set_aspect(abs(x1-x0)/abs(y1-y0))       # make square plot
xtick_spacing = 0.5
ytick_spacing = 2.0
# Each spacing goes on its own axis.
ax.xaxis.set_major_locator(ticker.MultipleLocator(xtick_spacing))
ax.yaxis.set_major_locator(ticker.MultipleLocator(ytick_spacing))
ax.scatter(sigTrain[:,0], sigTrain[:,1], s=3, color='dodgerblue', marker='o')
ax.scatter(bkgTrain[:,0], bkgTrain[:,1], s=3, color='red', marker='o')

# Add decision boundary to scatter plot: evaluate t on a fine grid covering
# the data (plus a 0.5 margin) and draw its zero contour.  The model outputs
# logits, so t = 0 is where the class probability p = sigmoid(t) is 0.5.
# float() turns the 0-d torch tensors into plain Python floats for NumPy.
x_min = float(X[:, 0].min()) - 0.5
x_max = float(X[:, 0].max()) + 0.5
y_min = float(X[:, 1].min()) - 0.5
y_max = float(X[:, 1].max()) + 0.5
h = 0.01  # step size in the mesh
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
grid = torch.tensor(np.c_[xx.ravel(), yy.ravel()], dtype=torch.float32)
with torch.no_grad():
    t_values = model(grid)   # gradients not needed
t_values_np = t_values.reshape(xx.shape).numpy()
ax.contour(xx, yy, t_values_np, levels=[0],
           linestyles=['solid'], colors='k')  # or multiple levels, linestyles
ax.set_xlabel(r'$x_{1}$', labelpad=0)
ax.set_ylabel(r'$x_{2}$', labelpad=15)
ax.set_title('training sample')
fig.savefig("x1_x2_scatterplot.pdf", format='pdf')
plt.show()

# Histogram of the decision function t (the network's raw output, before any
# sigmoid) over the test sample, signal and background drawn separately.
plt.figure()                                     # new window
with torch.no_grad():                            # inference only
    tTest = model(X_test)
tTest_np = tTest.numpy().reshape(y_test_np.shape)   # one t value per test event
tSig = tTest_np[y_test_np == 1]
tBkg = tTest_np[y_test_np == 0]

# 50 equal-width bins spanning the range of t rounded outwards to integers.
nBins = 50
tMin, tMax = np.floor(tTest_np.min()), np.ceil(tTest_np.max())
bins = np.linspace(tMin, tMax, nBins + 1)

plt.title('test sample')
plt.xlabel('decision function $t$', labelpad=3)
plt.ylabel('$f(t)$', labelpad=3)
histStyle = dict(bins=bins, density=True, histtype='step', fill=False)
n, bins, patches = plt.hist(tSig, color='dodgerblue', **histStyle)
n, bins, patches = plt.hist(tBkg, color='red', alpha=0.5, **histStyle)
plt.savefig("decision_function_hist.pdf", format='pdf')
plt.show()

# And another histogram after mapping t onto the class probability
# p = sigmoid(t) = 1 / (1 + exp(-t)), which lies in [0, 1]:
plt.figure()                                     # new window
# torch.sigmoid is evaluated in a numerically stable way; the explicit
# 1/(1 + np.exp(-t)) overflows (with a RuntimeWarning) once -t exceeds
# about 88 for float32 t.
pBkg = torch.sigmoid(torch.from_numpy(tBkg)).numpy()
pSig = torch.sigmoid(torch.from_numpy(tSig)).numpy()
nBins = 50
pMin = 0.
pMax = 1.
bins = np.linspace(pMin, pMax, nBins+1)
plt.xlabel('class probability $p$', labelpad=3)
plt.ylabel('$f(p)$', labelpad=3)
plt.title('test sample')
plt.hist(pSig, bins=bins, density=True, histtype='step',
         fill=False, color='dodgerblue')
plt.hist(pBkg, bins=bins, density=True, histtype='step',
         fill=False, color='red', alpha=0.5)
plt.savefig("p_hist.pdf", format='pdf')
plt.show()
