import gzip
import json
import numpy as np
from numpy import gradient
import torch
import torch.nn as nn
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
import random
import seaborn as sns
from matplotlib import rcParams
import pandas as pd
from sklearn import metrics
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib.colors as mcolors
import time
import re
import shap
from pyswarm import pso

# Global matplotlib typography: the family and size below are pushed into
# rcParams so that every figure created afterwards picks them up.
font_properties = dict(family='Times New Roman', size=12)
rcParams.update({
    'font.family': font_properties['family'],
    'font.size': font_properties['size'],
})

def set_random_seed(seed):
    """Seed Python's, NumPy's and PyTorch's random generators with ``seed``.

    cuDNN is additionally switched to deterministic mode with benchmarking
    disabled, so repeated runs pick the same kernels.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

#%%
# Load the saved JSON file {Tesla dataset}
with gzip.open('data_all_cells.json.gz', 'rt', encoding='utf-8') as f:
    data_all_cells = json.load(f)

# Every top-level key is one battery; the cell number is the digit run found
# between two underscores in the key, compared without its leading zeros.
all_batteries = list(data_all_cells)
all_battery_num = [re.search(r'_(\d+)_', entry).group(1) for entry in all_batteries]
all_battery_num_stripped = [num.lstrip('0') for num in all_battery_num]

# Read the defined training cell numbers (one per line); numbers that do not
# appear in the dataset are silently ignored.
with open('train_cells.txt', 'r') as file:
    train_cells_content = file.read()
train_cells = train_cells_content.splitlines()
train_batteries_indices = []
for cell in train_cells:
    if cell in all_battery_num_stripped:
        train_batteries_indices.append(all_battery_num_stripped.index(cell))
train_batteries = [all_batteries[i] for i in train_batteries_indices]

# Read the defined testing cell numbers, handled exactly like the training list.
with open('test_cells.txt', 'r') as file:
    test_cells_content = file.read()
test_cells = test_cells_content.splitlines()
test_batteries_indices = []
for cell in test_cells:
    if cell in all_battery_num_stripped:
        test_batteries_indices.append(all_battery_num_stripped.index(cell))
test_batteries = [all_batteries[i] for i in test_batteries_indices]

#%% Load half cell data and fitting function construction
# NOTE: folder_loc is defined but not used below; both CSV files are read
# relative to the current working directory.
folder_loc = '..\\006_half_cell'

# Keep the lookup tables on the GPU when one is available and fall back to the
# CPU otherwise, so the script also runs on machines without CUDA (a hard-coded
# 'cuda' target raises there).
ocp_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Anode half-cell table.
file_name = 'anode_SiO_Gr_discharge_Cover5_smoothed_dvdq_JS.csv'
file_path = f'{file_name}'
OCPn_data = pd.read_csv(file_path)

# Cathode half-cell table.
file_name = 'cathode_NCA_discharge_Cover5_smoothed_dvdq_JS.csv'
file_path = f'{file_name}'
OCPp_data = pd.read_csv(file_path)

# The 'SOC_linspace' column is the interpolation grid and 'Voltage' the values
# looked up on it; torch.searchsorted in torch_interpolate presumably relies on
# the grid being sorted in ascending order -- confirm against the CSV files.
OCPn_SOC = torch.tensor(OCPn_data['SOC_linspace'].values, dtype=torch.float32).to(ocp_device)
OCPn_V = torch.tensor(OCPn_data['Voltage'].values, dtype=torch.float32).to(ocp_device)
OCPp_SOC = torch.tensor(OCPp_data['SOC_linspace'].values, dtype=torch.float32).to(ocp_device)
OCPp_V = torch.tensor(OCPp_data['Voltage'].values, dtype=torch.float32).to(ocp_device)
def torch_interpolate(x, xp, fp):
    """Piecewise-linear interpolation of the table (xp, fp) at the points x.

    For every query the bracketing pair of grid nodes is located with
    ``torch.searchsorted``. Queries that fall outside the grid are not
    clipped: they are extrapolated along the first or last segment.

    Args:
        x: tensor of query points (any shape).
        xp: 1-D tensor of grid nodes, expected in ascending order.
        fp: 1-D tensor of table values, one per grid node.

    Returns:
        Tensor with the shape of ``x`` holding the interpolated values.
    """
    # Index of the right-hand node of each bracketing segment; clamping keeps
    # both the node and its left neighbour inside the table.
    upper = torch.searchsorted(xp, x).clamp(1, len(xp) - 1)
    lower = upper - 1

    x_left, x_right = xp[lower], xp[upper]
    f_left, f_right = fp[lower], fp[upper]
    return f_left + (f_right - f_left) * (x - x_left) / (x_right - x_left)

def OCP_p(SOC_p):
    """Look up the cathode half-cell voltage table at ``SOC_p``.

    The query tensor is first moved to the device that holds the table, then
    linearly interpolated on (OCPp_SOC, OCPp_V).
    """
    query = SOC_p.to(OCPp_SOC.device)
    return torch_interpolate(query, OCPp_SOC, OCPp_V)

def OCP_n(SOC_n):
    """Look up the anode half-cell voltage table at ``SOC_n``.

    The query tensor is first moved to the device that holds the table, then
    linearly interpolated on (OCPn_SOC, OCPn_V).
    """
    query = SOC_n.to(OCPn_SOC.device)
    return torch_interpolate(query, OCPn_SOC, OCPn_V)
#%%  extract data for health diagnosis
def extract_and_interpolate(data, idx_seed, num_points=200):
    """
    Build (partial charging curve -> 0.2C discharge curve) samples for one cell.

    Every entry of ``data['rpt']`` except the last one that has both a 'work'
    and a '02C' record is sampled three times. For each draw a random window of
    the charging segment is cut out (start in its first 30 %, end in its last
    30 %), normalised and resampled to ``num_points`` points. The matching
    '02C' discharge curve is resampled the same way and used as the target.

    Args:
        data: dict of one cell; ``data['rpt']`` is a list of dicts whose 'work'
            entry provides 'current', 'voltage', 'charge_capacity' and 'EFC',
            and whose '02C' entry provides 'discharge_capacity' and 'voltage'.
        idx_seed: base seed; draw ``k`` of every entry uses ``idx_seed + k``.
        num_points: number of resampled points per curve.

    Returns:
        inputs: array (num_samples, num_points, 3) with the features
            [voltage / 4.2, capacity since window start / 4.84, EFC / 500].
        outputs: array (num_samples, num_points, 2) with
            [discharge voltage / 4.2, discharge capacity / 4.84].
        soc_range_work: list with, per sample, the ratio of the charged
            capacity in the window to the maximum discharge capacity.
        start_v: list with the first voltage of each window (in volts).
        end_v: list with the last voltage of each window (in volts).
        When no sample survives the filters the two arrays are empty but keep
        their three-dimensional layout, so callers can still index and
        concatenate them.
    """
    nominal_capacity = 4.84
    inputs = []
    outputs = []
    start_v = []
    end_v = []
    soc_range_work = []

    # Common resampling grid shared by all curves.
    x_new = np.linspace(0, 1, num_points)

    def _resample(values):
        # Linearly resample a 1-D series onto the num_points-long grid.
        return interp1d(np.linspace(0, 1, len(values)), values, kind='linear')(x_new)

    rpt_list = data['rpt']
    for rpt_data in rpt_list:
        for round_idx in range(3):  # sample the partial charging curve several times
            # Re-seed on every draw (even a skipped one) so each round gets its own window.
            set_random_seed(idx_seed + round_idx)
            if 'work' not in rpt_data or '02C' not in rpt_data:
                continue
            if rpt_data is rpt_list[-1]:
                # The last RPT entry is never used.
                continue

            charge_current = np.array(rpt_data['work']['current'])
            charge_voltage = np.array(rpt_data['work']['voltage'])
            charge_capacity = np.array(rpt_data['work']['charge_capacity'])
            EFCs = np.array(rpt_data['work']['EFC'])
            if charge_current.size < 50:
                continue

            # The usable segment ends at the first voltage drop; without a drop
            # it ends at the last index.
            for charge_end_idx in range(2, len(charge_current)):
                if charge_voltage[charge_end_idx] < charge_voltage[charge_end_idx - 1]:
                    break

            # Random window: start in the first 30 %, end in the last 30 %.
            start_index = random.randint(0, int(0.3 * charge_end_idx))
            end_index = random.randint(int(0.7 * charge_end_idx), charge_end_idx)
            window = slice(start_index, end_index)
            if charge_current[window].size < 50:
                continue

            # Normalise; capacity is counted from the start of the window.
            capacity_offset = charge_capacity[start_index] / nominal_capacity
            charge_voltage = charge_voltage[window] / 4.2
            charge_capacity = charge_capacity[window] / nominal_capacity - capacity_offset
            EFCs = EFCs[window] / 500

            charge_voltage_interp = _resample(charge_voltage)
            charge_capacity_interp = _resample(charge_capacity)
            charge_EFCs_interp = _resample(EFCs)

            # 0.2C discharge: start at the last sample whose capacity is still
            # <= 0. If there is no such sample the whole record is used
            # (slicing from len() would leave nothing to interpolate).
            discharge_capacity = np.array(rpt_data['02C']['discharge_capacity'])
            discharge_voltage = np.array(rpt_data['02C']['voltage'])
            zero_capacity_idx = np.where(discharge_capacity <= 0)[0]
            dis_start_idx = zero_capacity_idx[-1] if len(zero_capacity_idx) > 0 else 0

            discharge_capacity = discharge_capacity[dis_start_idx:] / nominal_capacity
            discharge_voltage = discharge_voltage[dis_start_idx:] / 4.2
            if discharge_capacity.size < 2:
                # A curve needs at least two points to be resampled.
                continue
            discharge_voltage_interp = _resample(discharge_voltage)
            discharge_capacity_interp = _resample(discharge_capacity)

            # Reject curves with a capacity jump, a too small discharge
            # capacity, or a too short charging window.
            discharge_capacity_diff = np.diff(discharge_capacity_interp)
            max_charge_capacity = max(charge_capacity)
            max_discharge_capacity = max(discharge_capacity)
            if (max(discharge_capacity_diff) > 0.05
                    or max_discharge_capacity < 0.7
                    or max_charge_capacity * nominal_capacity < 0.2):
                continue

            input_data = np.stack(
                [charge_voltage_interp, charge_capacity_interp, charge_EFCs_interp], axis=-1)
            output_data = np.stack(
                [discharge_voltage_interp, discharge_capacity_interp], axis=-1)

            soc_range_work.append(max_charge_capacity / max_discharge_capacity)
            start_v.append(charge_voltage[0] * 4.2)
            end_v.append(charge_voltage[-1] * 4.2)
            inputs.append(input_data)
            outputs.append(output_data)

    if inputs:
        inputs = np.array(inputs)    # (num_samples, num_points, 3)
        outputs = np.array(outputs)  # (num_samples, num_points, 2)
    else:
        # Keep the 3-D layout even without samples.
        inputs = np.empty((0, num_points, 3))
        outputs = np.empty((0, num_points, 2))
    return inputs, outputs, soc_range_work, start_v, end_v

# Construct training and test datasets: per-cell results are collected in
# lists first and concatenated along the sample axis at the end.
train_inputs, train_outputs = [], []
test_inputs, test_outputs = [], []
train_start_v, train_end_v, train_soc_range_work = [], [], []
test_start_v, test_end_v, test_soc_range_work = [], [], []
nominal_capacity = 4.84

# ---- Training cells: extract samples and draw one grey fade curve per cell.
plt.figure(num=None, figsize=(8/2.54, 6/2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
for i, battery in enumerate(train_batteries):  # train_batteries all_batteries
    print("Train cell", battery)
    cap = []
    input_data, output_data, soc_range_work_data, start_v_data, end_v_data = extract_and_interpolate(
        data_all_cells[battery], i)
    train_inputs.append(input_data)
    train_outputs.append(output_data)
    train_start_v.append(start_v_data)
    train_end_v.append(end_v_data)
    train_soc_range_work.append(soc_range_work_data)
    # x: last EFC feature de-normalised (*500); y: last discharge capacity in Ah.
    plt.plot(input_data[:, -1, 2]*500, output_data[:, -1, 1]*nominal_capacity, 'grey')
# Dashed guide lines at 80 %, 75 % and 70 % of the nominal capacity.
efc_axis = np.arange(0, 1500, 100)
for capacity_fraction in (0.8, 0.75, 0.7):
    plt.plot(efc_axis, capacity_fraction*nominal_capacity*np.ones(len(efc_axis)), '--', color='blue')
plt.xlabel('EFCs')
plt.ylabel('Capacity [Ah]')
plt.show()
plt.show()

# ---- Test cells: same procedure on the held-out cell list.
plt.figure(num=None, figsize=(8/2.54, 6/2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
for i, battery in enumerate(test_batteries):
    print("Test cell", battery)
    input_data, output_data, soc_range_work_data, start_v_data, end_v_data = extract_and_interpolate(
        data_all_cells[battery], i)
    test_inputs.append(input_data)
    test_outputs.append(output_data)
    test_start_v.append(start_v_data)
    test_end_v.append(end_v_data)
    test_soc_range_work.append(soc_range_work_data)
    plt.plot(input_data[:, -1, 2]*500, output_data[:, -1, 1]*nominal_capacity, 'grey')
efc_axis = np.arange(0, 1500, 100)
for capacity_fraction in (0.8, 0.75, 0.7):
    plt.plot(efc_axis, capacity_fraction*nominal_capacity*np.ones(len(efc_axis)), '--', color='blue')
plt.xlabel('EFCs')
plt.ylabel('Capacity [Ah]')
plt.show()

# Convert the per-cell lists to flat NumPy arrays (samples of all cells stacked).
train_inputs = np.concatenate(train_inputs, axis=0)
train_outputs = np.concatenate(train_outputs, axis=0)
test_inputs = np.concatenate(test_inputs, axis=0)
test_outputs = np.concatenate(test_outputs, axis=0)

train_start_v = np.concatenate(train_start_v, axis=0)
train_end_v = np.concatenate(train_end_v, axis=0)
train_soc_range_work = np.concatenate(train_soc_range_work, axis=0)
test_start_v = np.concatenate(test_start_v, axis=0)
test_end_v = np.concatenate(test_end_v, axis=0)
test_soc_range_work = np.concatenate(test_soc_range_work, axis=0)

print(f"Training data shape: {train_inputs.shape}, {train_outputs.shape}")
print(f"Test data shape: {test_inputs.shape}, {test_outputs.shape}")
#%%
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device: {}'.format(device))
class EncoderGRU(nn.Module):
    """GRU encoder that maps a 3-feature sequence to four non-negative scalars.

    The last time step of a single-layer GRU (hidden size 64) goes through a
    64-unit ReLU layer and four separate one-unit ReLU heads. Each head starts
    with zero weights and a bias of 0.9, so a fresh model outputs 0.9
    everywhere.
    """

    def __init__(self):
        super().__init__()
        self.rnn = nn.GRU(input_size=3, hidden_size=64, num_layers=1, batch_first=True)
        self.fc = nn.Linear(64, 64)
        self.cp_out = nn.Linear(64, 1)
        self.cn_out = nn.Linear(64, 1)
        self.x0_out = nn.Linear(64, 1)
        self.y0_out = nn.Linear(64, 1)
        # Constant start for every head: weights 0, bias 0.9.
        for head in (self.cp_out, self.cn_out, self.x0_out, self.y0_out):
            nn.init.zeros_(head.weight)
            nn.init.constant_(head.bias, 0.9)

    def forward(self, x):
        """Return (Cp, Cn, x0, y0), each (batch, 1), for x of shape (batch, seq, 3)."""
        sequence_out, _ = self.rnn(x)
        features = torch.relu(self.fc(sequence_out[:, -1, :]))
        heads = (self.cp_out, self.cn_out, self.x0_out, self.y0_out)
        Cp, Cn, x0, y0 = (torch.relu(head(features)) for head in heads)
        return Cp, Cn, x0, y0

class EncoderRNN(nn.Module):
    """Plain-RNN encoder that maps a 3-feature sequence to four non-negative scalars.

    The last time step of a single-layer RNN (hidden size 64) goes through a
    64-unit ReLU layer and four separate one-unit ReLU heads. Each head starts
    with zero weights and a bias of 0.9, so a fresh model outputs 0.9
    everywhere.
    """

    def __init__(self):
        super().__init__()
        self.rnn = nn.RNN(input_size=3, hidden_size=64, num_layers=1, batch_first=True)
        self.fc = nn.Linear(64, 64)
        self.cp_out = nn.Linear(64, 1)
        self.cn_out = nn.Linear(64, 1)
        self.x0_out = nn.Linear(64, 1)
        self.y0_out = nn.Linear(64, 1)
        # Constant start for every head: weights 0, bias 0.9.
        for head in (self.cp_out, self.cn_out, self.x0_out, self.y0_out):
            nn.init.zeros_(head.weight)
            nn.init.constant_(head.bias, 0.9)

    def forward(self, x):
        """Return (Cp, Cn, x0, y0), each (batch, 1), for x of shape (batch, seq, 3)."""
        sequence_out, _ = self.rnn(x)
        features = torch.relu(self.fc(sequence_out[:, -1, :]))
        heads = (self.cp_out, self.cn_out, self.x0_out, self.y0_out)
        Cp, Cn, x0, y0 = (torch.relu(head(features)) for head in heads)
        return Cp, Cn, x0, y0
    
    
class EncoderLSTM(nn.Module):
    """LSTM encoder that maps a 3-feature sequence to four non-negative scalars.

    The last time step of a single-layer LSTM (hidden size 64) goes through a
    64-unit ReLU layer and four separate one-unit ReLU heads. Each head starts
    with zero weights and a bias of 0.9, so a fresh model outputs 0.9
    everywhere.
    """

    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(input_size=3, hidden_size=64, num_layers=1, batch_first=True)
        self.fc = nn.Linear(64, 64)
        self.cp_out = nn.Linear(64, 1)
        self.cn_out = nn.Linear(64, 1)
        self.x0_out = nn.Linear(64, 1)
        self.y0_out = nn.Linear(64, 1)
        # Constant start for every head: weights 0, bias 0.9.
        for head in (self.cp_out, self.cn_out, self.x0_out, self.y0_out):
            nn.init.zeros_(head.weight)
            nn.init.constant_(head.bias, 0.9)

    def forward(self, x):
        """Return (Cp, Cn, x0, y0), each (batch, 1), for x of shape (batch, seq, 3)."""
        sequence_out, _ = self.rnn(x)
        features = torch.relu(self.fc(sequence_out[:, -1, :]))
        heads = (self.cp_out, self.cn_out, self.x0_out, self.y0_out)
        Cp, Cn, x0, y0 = (torch.relu(head(features)) for head in heads)
        return Cp, Cn, x0, y0

class Encoder1DCNN(nn.Module):
    """1-D CNN encoder that maps a feature sequence to four non-negative scalars.

    Three Conv1d(+ReLU)+MaxPool1d(3) stages are followed by a ReLU dense layer
    and four one-unit ReLU heads. Each head starts with zero weights and a bias
    of 0.9, so a fresh model outputs 0.9 everywhere.

    Args:
        num_features: number of input channels (features per time step).
        num_filters: number of channels produced by every convolution.
        kernel_size: convolution kernel size (padding is fixed to 1).
        hidden_size: width of the dense layer in front of the heads.
        seq_len: length of the input sequences. The size of the flattened
            convolution output is derived from it, so non-default filter
            counts, kernel sizes or sequence lengths work; the defaults give
            the former fixed value of 448.
    """

    _PADDING = 1
    _POOL_SIZE = 3
    _NUM_STAGES = 3

    def __init__(self, num_features=3, num_filters=64, kernel_size=3, hidden_size=64, seq_len=200):
        super(Encoder1DCNN, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=num_features, out_channels=int(1*num_filters), kernel_size=kernel_size, padding=self._PADDING)
        self.conv2 = nn.Conv1d(in_channels=int(1*num_filters), out_channels=num_filters, kernel_size=kernel_size, padding=self._PADDING)
        self.conv3 = nn.Conv1d(in_channels=num_filters, out_channels=int(1*num_filters), kernel_size=kernel_size, padding=self._PADDING)
        self.pool = nn.MaxPool1d(kernel_size=self._POOL_SIZE)
        flat_features = int(1*num_filters) * self._pooled_length(seq_len, kernel_size)
        self.fc = nn.Linear(flat_features, hidden_size)
        self.cp_out = nn.Linear(hidden_size, 1)
        self.cn_out = nn.Linear(hidden_size, 1)
        self.x0_out = nn.Linear(hidden_size, 1)
        self.y0_out = nn.Linear(hidden_size, 1)
        # Constant start for every head: weights 0, bias 0.9.
        with torch.no_grad():
            for head in (self.cp_out, self.cn_out, self.x0_out, self.y0_out):
                head.weight.fill_(0.0)
                head.bias.fill_(0.9)

    @classmethod
    def _pooled_length(cls, seq_len, kernel_size):
        """Return the sequence length left after the three conv+pool stages."""
        length = seq_len
        for _ in range(cls._NUM_STAGES):
            # Conv1d with stride 1 and dilation 1.
            length = length + 2 * cls._PADDING - kernel_size + 1
            # MaxPool1d whose stride equals its kernel size (floor mode).
            length = (length - cls._POOL_SIZE) // cls._POOL_SIZE + 1
            if length < 1:
                raise ValueError(
                    f"seq_len={seq_len} is too short for kernel_size={kernel_size}")
        return length

    def forward(self, x):
        """Return (Cp, Cn, x0, y0), each (batch, 1), for x of shape (batch, seq_len, num_features)."""
        # Conv1d expects channels first: (batch, num_features, seq_len).
        x = x.transpose(1, 2)
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc(x))
        Cp = torch.relu(self.cp_out(x))
        Cn = torch.relu(self.cn_out(x))
        x0 = torch.relu(self.x0_out(x))
        y0 = torch.relu(self.y0_out(x))
        return Cp, Cn, x0, y0

class EncoderTransformer(nn.Module):
    """Transformer encoder that maps a feature sequence to four non-negative scalars.

    The raw features are used directly as the model dimension
    (``d_model = num_features``), so ``nhead`` must divide ``num_features``.
    The encoded sequence is flattened, passed through a ReLU dense layer and
    four one-unit ReLU heads. Each head starts with zero weights and a bias of
    0.9, so a fresh model outputs 0.9 everywhere.
    """

    def __init__(self, num_features=3, seq_len=200, hidden_size=64, num_layers=1, nhead=3, dim_feedforward=64, dropout=0.05):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=num_features,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.fc = nn.Linear(num_features * seq_len, hidden_size)
        self.cp_out = nn.Linear(hidden_size, 1)
        self.cn_out = nn.Linear(hidden_size, 1)
        self.x0_out = nn.Linear(hidden_size, 1)
        self.y0_out = nn.Linear(hidden_size, 1)

        # Constant start for every head: weights 0, bias 0.9.
        for head in (self.cp_out, self.cn_out, self.x0_out, self.y0_out):
            nn.init.zeros_(head.weight)
            nn.init.constant_(head.bias, 0.9)

    def forward(self, x):
        """Return (Cp, Cn, x0, y0), each (batch, 1), for x of shape (batch, seq_len, num_features)."""
        batch_size = x.size(0)
        # The encoder layer is not batch-first: feed (seq_len, batch, features).
        encoded = self.transformer_encoder(x.transpose(0, 1))
        # Back to batch-first and flatten to (batch, seq_len * num_features).
        flattened = encoded.transpose(0, 1).reshape(batch_size, -1)
        features = torch.relu(self.fc(flattened))
        Cp = torch.relu(self.cp_out(features))
        Cn = torch.relu(self.cn_out(features))
        x0 = torch.relu(self.x0_out(features))
        y0 = torch.relu(self.y0_out(features))
        return Cp, Cn, x0, y0

class EncoderMLP(nn.Module):
    """MLP encoder that maps a flattened 200-step sequence to four non-negative scalars.

    The input is flattened to ``200 * num_feature`` values, passed through
    three ReLU dense layers and four one-unit ReLU heads. Each head starts with
    zero weights and a bias of 0.9, so a fresh model outputs 0.9 everywhere.
    """

    def __init__(self, num_feature=3, hidden_size=128, output_size=64):
        super().__init__()
        flat_inputs = 200 * num_feature  # sequence length is fixed to 200 here
        width = int(1*hidden_size)
        self.fc1 = nn.Linear(flat_inputs, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, output_size)

        self.cp_out = nn.Linear(output_size, 1)
        self.cn_out = nn.Linear(output_size, 1)
        self.x0_out = nn.Linear(output_size, 1)
        self.y0_out = nn.Linear(output_size, 1)
        # Constant start for every head: weights 0, bias 0.9.
        for head in (self.cp_out, self.cn_out, self.x0_out, self.y0_out):
            nn.init.zeros_(head.weight)
            nn.init.constant_(head.bias, 0.9)

    def forward(self, x):
        """Return (Cp, Cn, x0, y0), each (batch, 1), for x of shape (batch, 200, num_feature)."""
        hidden = x.view(x.size(0), -1)  # (batch, 200 * num_feature)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = torch.relu(layer(hidden))
        Cp = torch.relu(self.cp_out(hidden))
        Cn = torch.relu(self.cn_out(hidden))
        x0 = torch.relu(self.x0_out(hidden))
        y0 = torch.relu(self.y0_out(hidden))
        return Cp, Cn, x0, y0

class Decoder(nn.Module):
    """Dense decoder that expands four scalars into a two-channel curve.

    The four inputs are concatenated, passed through two ReLU layers
    (4 -> 128 -> 256) and projected by two independent linear heads into
    ``num_points`` values each. The result has shape (batch, num_points, 2)
    with the ``fcv`` output in channel 0 and the ``fcq`` output in channel 1.
    """

    def __init__(self, num_points=200):
        super().__init__()
        self.num_points = num_points
        self.fc1 = nn.Linear(4, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fcv = nn.Linear(256, num_points)  # channel 0 of the output
        self.fcq = nn.Linear(256, num_points)  # channel 1 of the output

    def forward(self, Cp, Cn, x0, y0):
        """Return the (batch, num_points, 2) curve for four (batch, 1) inputs."""
        latent = torch.cat([Cp, Cn, x0, y0], dim=-1)
        hidden = torch.relu(self.fc2(torch.relu(self.fc1(latent))))
        return torch.stack([self.fcv(hidden), self.fcq(hidden)], dim=-1)


class BatteryModel(nn.Module):
    """Encoder/decoder model that also evaluates a half-cell based OCV.

    The selected encoder turns the input sequence into four scalars
    (Cp, Cn, x0, y0); the decoder expands them into a two-channel curve
    (channel 0: normalised voltage, channel 1: normalised capacity Q). From Q
    the electrode states ``y0 - Q / Cp`` and ``x0 - Q / Cn`` are formed and an
    OCV is computed as ``OCP_p(SOC_p) - OCP_n(SOC_n)``, scaled by 1/4.2.

    Args:
        encoder_name: one of 'MLP', '1DCNN', 'RNN', 'LSTM', 'GRU', 'Transformer'.
        num_points: number of points of the decoded curve.

    Raises:
        ValueError: if ``encoder_name`` is not one of the names above.
    """

    # Lower bound applied to Cp / Cn when they are used as divisors. Both come
    # out of a ReLU and can be exactly zero, which would otherwise produce
    # inf/NaN states.
    _MIN_CAPACITY = 1e-3

    def __init__(self,encoder_name, num_points=200):
        super(BatteryModel, self).__init__()
        encoder_classes = {
            'MLP': EncoderMLP,
            '1DCNN': Encoder1DCNN,
            'RNN': EncoderRNN,
            'LSTM': EncoderLSTM,
            'GRU': EncoderGRU,
            'Transformer': EncoderTransformer,
        }
        try:
            encoder_class = encoder_classes[encoder_name]
        except (KeyError, TypeError):
            raise ValueError(f"Unknown encoder name: {encoder_name}") from None
        self.encoder = encoder_class()
        self.decoder = Decoder(num_points=num_points)

    def forward(self, x):
        """Return (curve, Cp, Cn, x0, y0, calculated OCV, SOC_p, SOC_n) for a batch x."""
        Cp, Cn, x0, y0 = self.encoder(x)
        OCV_Q_curve = self.decoder(Cp, Cn, x0, y0)
        capacity_curve = OCV_Q_curve[:, :, 1]
        # Guard the divisions: the returned Cp / Cn stay untouched, only the
        # divisors are clamped away from zero.
        SOC_p = y0 - (capacity_curve / torch.clamp(Cp, min=self._MIN_CAPACITY))
        SOC_n = x0 - (capacity_curve / torch.clamp(Cn, min=self._MIN_CAPACITY))
        # OCV from the interpolated half-cell tables, normalised by 4.2.
        Calculated_OCV = OCP_p(SOC_p) - OCP_n(SOC_n)
        Calculated_OCV = Calculated_OCV/4.2

        return OCV_Q_curve, Cp, Cn, x0, y0, Calculated_OCV, SOC_p, SOC_n

def regression_loss(predicted_ocv_Q_curve, true_ocv_Q_curve):
    """Return the sum of the per-channel MSEs of two (batch, points, 2) curves.

    Channel 0 and channel 1 are averaged separately and then added, so both
    channels contribute with equal weight.
    """
    mse = nn.functional.mse_loss
    channel0_error = mse(predicted_ocv_Q_curve[:, :, 0], true_ocv_Q_curve[:, :, 0])
    channel1_error = mse(predicted_ocv_Q_curve[:, :, 1], true_ocv_Q_curve[:, :, 1])
    return channel0_error + channel1_error

def physics_loss( Cp, Cn, x0, y0, predicted_ocv_Q_curve,true_ocv_Q_curve, eps=1e-3):
    """Return (physics loss, constraint loss) for a predicted curve.

    The electrode states ``y0 - Q / Cp`` and ``x0 - Q / Cn`` are computed from
    the predicted capacity channel Q, and an OCV is derived from the half-cell
    tables as ``(OCP_p(SOC_p) - OCP_n(SOC_n)) / 4.2``.

    Args:
        Cp, Cn, x0, y0: (batch, 1) encoder outputs.
        predicted_ocv_Q_curve: (batch, points, 2) tensor; channel 0 is the
            normalised voltage, channel 1 the normalised capacity.
        true_ocv_Q_curve: accepted for signature compatibility; not used.
        eps: lower bound applied to Cp and Cn when they are used as divisors,
            so a zero ReLU output cannot produce inf/NaN states.

    Returns:
        total_physics_loss: MSE between the derived OCV and the predicted
            voltage channel.
        total_constrain_loss: sum of hinge penalties that keep
            the derived and predicted voltages inside [2.7/4.2, 1],
            the predicted capacity non-negative with a first point <= 0,
            both electrode states inside [0, 1], and
            Cp and Cn inside [0.1, 1.1].
    """
    capacity_curve = predicted_ocv_Q_curve[:,:,1]
    voltage_curve = predicted_ocv_Q_curve[:,:,0]
    # Only the divisors are clamped; the range penalties below still see the
    # raw Cp / Cn values.
    SOC_p = y0 - (capacity_curve / torch.clamp(Cp, min=eps))
    SOC_n = x0 - (capacity_curve / torch.clamp(Cn, min=eps))
    Calculated_OCV = OCP_p(SOC_p) - OCP_n(SOC_n)
    Calculated_OCV = Calculated_OCV/4.2
    predict_range_loss = (
        torch.mean(torch.relu(Calculated_OCV[:, :] - 1))+
        torch.mean(torch.relu(2.7/4.2-Calculated_OCV[:, :]))+
        torch.mean(torch.relu(voltage_curve - 1))+
        torch.mean(torch.relu(2.7/4.2-voltage_curve))+
        torch.mean(torch.relu(-capacity_curve))+
        torch.mean(torch.relu(predicted_ocv_Q_curve[:, 0, 1]))
        )

    soc_p_constraint = torch.mean(torch.relu(SOC_p - 1.0) + torch.relu(0 - SOC_p))
    soc_n_constraint = torch.mean(torch.relu(SOC_n - 1.0) + torch.relu(0 - SOC_n))
    # Cp and Cn in reasonable ranges
    cp_constraint = torch.mean(torch.relu(Cp - 1.1) + torch.relu(0.1 - Cp))
    cn_constraint = torch.mean(torch.relu(Cn - 1.1) + torch.relu(0.1 - Cn))
    ocv_loss = nn.MSELoss()(Calculated_OCV, voltage_curve)

    # Total physical loss
    total_physics_loss = ocv_loss

    total_constrain_loss = (
        cp_constraint
        + cn_constraint
        + predict_range_loss
        + soc_p_constraint
        + soc_n_constraint
    )

    return total_physics_loss, total_constrain_loss


def total_loss(predicted_ocv_Q_curve, true_ocv_Q_curve, Cp, Cn, x0, y0, phys_weight, reg_weight, constraint_weight):
    """Combine regression, physics and constraint losses into one weighted sum.

    Returns:
        (weighted total, regression loss, physics loss, constraint loss); the
        three individual terms are returned unweighted.
    """
    reg_loss = regression_loss(predicted_ocv_Q_curve, true_ocv_Q_curve)
    phys_loss, constraint_loss = physics_loss(
        Cp, Cn, x0, y0, predicted_ocv_Q_curve, true_ocv_Q_curve)

    weighted_reg = reg_weight * reg_loss
    weighted_phys = phys_weight * phys_loss
    weighted_constraint = constraint_weight * constraint_loss
    total_loss_value = weighted_reg + weighted_phys + weighted_constraint
    return total_loss_value, reg_loss, phys_loss, constraint_loss

def error_evaluation(prediction, real_value):
    """Print and return the RMSE and MAE between two equally shaped arrays."""
    mean_squared = metrics.mean_squared_error(prediction, real_value)
    RMSE = mean_squared**0.5
    MAE = metrics.mean_absolute_error(prediction, real_value)
    print("RMSE, MAE", RMSE, MAE)
    return RMSE, MAE

#%%
set_random_seed(123)
# Move the four NumPy datasets to the selected device as float32 tensors.
train_inputs_tensor, train_outputs_tensor, test_inputs_tensor, test_outputs_tensor = (
    torch.tensor(array, dtype=torch.float32).to(device)
    for array in (train_inputs, train_outputs, test_inputs, test_outputs)
)

# Shuffle the training samples once (inputs and targets with the same permutation).
num_samples = train_inputs_tensor.shape[0]
shuffled_indices = torch.randperm(num_samples)
train_inputs_shuffled = train_inputs_tensor[shuffled_indices]
train_outputs_shuffled = train_outputs_tensor[shuffled_indices]

# Hold out the last 20 % of the shuffled samples for validation.
val_split = 0.2
num_train_samples = int((1 - val_split) * num_samples)

train_inputs_train, train_inputs_val = (
    train_inputs_shuffled[:num_train_samples],
    train_inputs_shuffled[num_train_samples:],
)
train_outputs_train, train_outputs_val = (
    train_outputs_shuffled[:num_train_samples],
    train_outputs_shuffled[num_train_samples:],
)
#%% diagnostic model
set_random_seed(123)
patience = 2000  # epochs without validation improvement before stopping
no_improvement = 0  # epochs since the last validation improvement

num_points = 200  # points per decoded curve
model = BatteryModel(encoder_name = 'MLP', num_points=num_points).to(device)
print(model)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Multiply the learning rate by 0.95 after 100 epochs without a better validation loss.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.95, patience=100)

best_loss = float('inf')  # best validation loss seen so far
best_model_path = 'best_model.pth'
max_epoch = 20000
phys_weight = 1
reg_weight = 1
constraint_weight = 0.1
min_lr = 5e-7  # training also stops once the learning rate drops below this
total_start_time = time.time()
for epoch in range(max_epoch):

    # ---- Training step on the full training split (full-batch).
    model.train()
    predicted_Q_ocv_curve, Cp, Cn, x0, y0,Calculated_Q_ocv, SOC_p, SOC_n = model(train_inputs_train)
    total_loss_value, reg_loss, phys_loss,constraint_loss = total_loss(predicted_Q_ocv_curve, train_outputs_train,
                                                       Cp, Cn, x0, y0, phys_weight, reg_weight,constraint_weight)

    optimizer.zero_grad()
    total_loss_value.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    # Update the weights BEFORE validating: the checkpoint written below must
    # hold exactly the weights that produced the validation loss it is
    # selected by.
    optimizer.step()

    # ---- Validation phase
    model.eval()
    with torch.no_grad():
        predicted_val_Q_ocv, Cp, Cn, x0, y0,Calculated_val_Q_ocv, SOC_p, SOC_n = model(train_inputs_val)
        val_total_loss_value, val_reg_loss, val_phys_loss,val_constraint_loss = total_loss(predicted_val_Q_ocv, train_outputs_val,
                                                                       Cp, Cn, x0, y0, phys_weight, reg_weight,constraint_weight)
    scheduler.step(val_total_loss_value)
    current_lr = optimizer.param_groups[0]["lr"]

    # Early stopping logic with best validation loss tracking
    if val_total_loss_value < best_loss:
        best_loss = val_total_loss_value
        no_improvement = 0  # Reset counter if improvement
        torch.save(model.state_dict(), best_model_path)  # Save the best model
        print(f'Epoch {epoch}, '
              f'Training Loss: {total_loss_value.item()} (Reg: {reg_loss.item()}, Phys: {phys_loss.item()}, Con: {constraint_loss.item()})\n'
              f'Validation Loss: {val_total_loss_value.item()} (Reg: {val_reg_loss.item()}, Phys: {val_phys_loss.item()}, Con: {val_constraint_loss.item()}), LR: {current_lr:.6f}')
    else:
        no_improvement += 1

    # Stop on exhausted patience or on a learning rate that has decayed away,
    # and report which of the two actually triggered.
    if no_improvement >= patience:
        print(f"Early stopping at epoch {epoch}, Validation loss has not improved in {patience} epochs.")
        break
    if current_lr < min_lr:
        print(f"Early stopping at epoch {epoch}, learning rate {current_lr:.2e} fell below {min_lr:.0e}.")
        break

total_end_time = time.time()
print('Time for training:', total_end_time-total_start_time)

#%% diagnostic results
# Restore the best checkpoint and run the test set through it without gradients.
model.load_state_dict(torch.load(best_model_path, weights_only=True))
model.eval()
with torch.no_grad():
    (predicted_test_ocv, Cp, Cn, x0, y0,
     Calculated_test_ocv, SOC_p, SOC_n) = model(test_inputs_tensor)

# Bring everything needed for plotting back to NumPy on the CPU.
predicted_ocv_curve_test = predicted_test_ocv.cpu().numpy()
Calculated_test_ocv = Calculated_test_ocv.cpu().numpy()
true_ocv_curve_test = test_outputs_tensor.cpu().numpy()
Cp, Cn, x0, y0 = (tensor.cpu().numpy() for tensor in (Cp, Cn, x0, y0))

# Capacity is read from the last point of channel 1 (true vs. predicted curve).
Cap_real = true_ocv_curve_test[:, -1, 1]
Cap_predict = predicted_ocv_curve_test[:, -1, 1]

cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap',['#E59693','#0073B1'])
cmap2 = mcolors.LinearSegmentedColormap.from_list('custom_cmap',['#facaa9','#558c83'])

# ---- Parity plot: predicted vs. real capacity, coloured by absolute error.
fig, ax = plt.subplots(figsize=(6.5/2.54, 5/2.54), dpi=600)
plt.tick_params(bottom=False, left=False)
orig_color_values = np.abs(Cap_real-np.array(Cap_predict))
norm = Normalize(vmin=orig_color_values.min(), vmax=orig_color_values.max()+1)
color_values = norm(orig_color_values)
scatter = plt.scatter(Cap_real, Cap_predict, c=np.abs(Cap_real-Cap_predict), alpha=1, cmap=cmap, marker='o', linewidth=0.0, s=10, edgecolors=None)
plt.plot(Cap_real,Cap_real,'-',color='dimgray')
divider = make_axes_locatable(ax)
cax = ax.inset_axes([0.7, 0.08, 0.05, 0.35])
cbar = plt.colorbar(scatter, cax=cax)
plt.tick_params(bottom=False, left=False)
# The scatter is coloured with the raw absolute errors, so the colorbar axis is
# in those raw units: ticks go at the raw values, not at their normalised
# counterparts.
ticks = np.linspace(orig_color_values.min(), orig_color_values.max(), num=3)
tick_labels = ["{:.2f}".format(value) for value in ticks]
cbar.set_ticks(ticks)
cbar.set_ticklabels(tick_labels)
plt.xlabel('Real',labelpad=1)
plt.ylabel('Predictions',labelpad=1)
ax.tick_params(axis='both', which='both', length=0)
plt.show()

# Error metrics of the capacity prediction.
RMSE_predicted,MAE_predicted = error_evaluation(Cap_predict.reshape(-1),Cap_real.reshape(-1))
residuals = Cap_predict-Cap_real
ss_res = np.sum(residuals**2)
ss_tot = np.sum((Cap_real - np.mean(Cap_real))**2)
r_squared = 1 - (ss_res / ss_tot)
print("R^2:", r_squared)


# ---- Signed error against the sampled SOC range of each charging window.
pre_error=Cap_predict-Cap_real
fig, ax = plt.subplots(figsize=(10/2.54, 6/2.54), dpi=600)
plt.tick_params(bottom=False, left=False)
scatter = plt.scatter(100*test_soc_range_work, pre_error, c=np.abs(Cap_real-Cap_predict), alpha=0.8, cmap=cmap, marker='o', linewidth=0.0, s=10, edgecolors=None)
plt.xlabel('SOC range [%]',labelpad=1)
plt.ylabel('SOH error',labelpad=1)
ax.tick_params(axis='both', which='both', length=0)
cbar = plt.colorbar(scatter)
cbar.set_label('Absolute error')
plt.tick_params(bottom=False, left=False)
ticks = np.linspace(orig_color_values.min(), orig_color_values.max(), num=3)
tick_labels = ["{:.2f}".format(value) for value in ticks]
cbar.set_ticks(ticks)
cbar.set_ticklabels(tick_labels)
plt.show()


# ---- Absolute error over the start/end voltage of each charging window.
fig, ax = plt.subplots(figsize=(10/2.54, 6/2.54), dpi=600)
plt.tick_params(bottom=False, left=False)
scatter = plt.scatter(test_start_v, test_end_v, c=np.abs(Cap_real-Cap_predict), alpha=0.8, cmap=cmap, marker='o', linewidth=0.0, s=10, edgecolors=None)
plt.xlabel('Start voltage [V]',labelpad=1)
plt.ylabel('End voltage [V]',labelpad=1)
ax.tick_params(axis='both', which='both', length=0)
cbar = plt.colorbar(scatter)
cbar.set_label('Absolute error')
plt.tick_params(bottom=False, left=False)
ticks = np.linspace(orig_color_values.min(), orig_color_values.max(), num=3)
tick_labels = ["{:.2f}".format(value) for value in ticks]
cbar.set_ticks(ticks)
cbar.set_ticklabels(tick_labels)
plt.show()


# Baseline ("prior") check: compare the real values with the Q value at the last time step
# of every test input, i.e. without using the model.
# Feature index 1 is Q (the input features are ordered V, Q, EFC).
prior_capacity = test_inputs_tensor.cpu().numpy()[:, -1, 1]

plt.figure(num=None, figsize=(6/2.54, 5/2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
plt.plot(Cap_real, prior_capacity, 'o')
plt.xlabel('Real')
plt.ylabel('Measurement')
plt.show()

RMSE_prior, MAE_prior = error_evaluation(prior_capacity.reshape(-1), Cap_real.reshape(-1))

# R^2 of the same baseline. The residuals use the Q channel (index 1), matching the plot and
# RMSE_prior/MAE_prior above; previously index 2 (EFC) was read here, so the printed score
# described a different quantity than the one evaluated.
residuals = prior_capacity.reshape(-1) - Cap_real
ss_res = np.sum(residuals**2)
ss_tot = np.sum((Cap_real - np.mean(Cap_real))**2)
r_squared = 1 - (ss_res / ss_tot)
print("R^2:", r_squared)


# Absolute and relative prediction errors, sorted to build their empirical CDFs.
error = np.abs(Cap_real - Cap_predict)
error_relative = np.abs(Cap_real - Cap_predict) / Cap_real
sorted_error = np.sort(error)
sorted_error_relative = np.sort(error_relative)
cumulative_distribution = np.linspace(0, 1, len(sorted_error))
cumulative_distribution_relative = np.linspace(0, 1, len(sorted_error_relative))

# Sample closest to the 95 % level on each CDF; error values are scaled by 100 for plotting.
abs_idx = np.argmin(np.abs(cumulative_distribution - 0.95))
abs_error_x, abs_error_y = 100 * sorted_error[abs_idx], cumulative_distribution[abs_idx]
rel_idx = np.argmin(np.abs(cumulative_distribution_relative - 0.95))
rel_error_x, rel_error_y = (100 * sorted_error_relative[rel_idx],
                            cumulative_distribution_relative[rel_idx])

plt.figure(num=None, figsize=(6.5 / 2.54, 5 / 2.54), dpi=600)
plt.ion()
plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
plt.tick_params(top='on', right='on', which='both')
plt.axhline(y=0.95, color='#daaa89', linestyle='--', linewidth=2, label='95%')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)

# One entry per CDF: (legend label, colour, sorted errors, CDF values, 95 % marker x, marker y).
cdf_series = (
    ('Absolute error', '#5ca7d1', sorted_error, cumulative_distribution,
     abs_error_x, abs_error_y),
    ('Relative error', '#f1b2c2', sorted_error_relative, cumulative_distribution_relative,
     rel_error_x, rel_error_y),
)
for cdf_label, cdf_colour, cdf_errors, cdf_values, _, _ in cdf_series:
    plt.plot(100 * cdf_errors, cdf_values, label=cdf_label, color=cdf_colour, linewidth=2)
# Mark the 95 % point of each curve and drop a vertical guide line through it.
for _, cdf_colour, _, _, marker_x, marker_y in cdf_series:
    plt.scatter([marker_x], [marker_y], color=cdf_colour, s=10, zorder=5)
    plt.axvline(x=marker_x, color='gray', linestyle='--', alpha=0.5, linewidth=1)

plt.xlabel('SOH error [%]')
plt.ylabel('Cumulative distribution')
legend = plt.legend(frameon=True, labelspacing=0.01, loc='lower right',
                    facecolor='white', edgecolor='white')
legend.get_frame().set_linewidth(0)
plt.show()

# Absolute difference between predicted and true curve values in channel 1, scaled by
# nominal_capacity; one row per test sample, one column per curve point.
error_matrix = abs(
    predicted_ocv_curve_test[:, :, 1] * nominal_capacity
    - true_ocv_curve_test[:, :, 1] * nominal_capacity
)

plt.figure(num=None, figsize=(7 / 2.54, 5 / 2.54), dpi=600)
plt.ion()
plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
ax = sns.heatmap(error_matrix, cmap=cmap2, xticklabels=8)

# Ten evenly spaced x ticks, labelled with the first sample's predicted channel-0 values
# scaled by 4.2.
voltage_values = predicted_ocv_curve_test[0, :, 0] * 4.2
x_ticks = np.linspace(0, len(voltage_values) - 1, 10).astype(int)
ax.set_xticks(x_ticks)
ax.set_xticklabels([f'{volt:.2f}' for volt in voltage_values[x_ticks]])
plt.xlabel('Voltage [V]')
plt.ylabel('Test Sample')
plt.show()

# Flattened predicted / true channel-1 values used for the metrics below.
flat_predicted_q = np.array(predicted_ocv_curve_test[:, :, 1].reshape(-1) * nominal_capacity)
flat_true_q = np.array(true_ocv_curve_test[:, :, 1].reshape(-1) * nominal_capacity)
RMSE_predicted, MAE_predicted = error_evaluation(flat_predicted_q, flat_true_q)

# Coefficient of determination: R^2 = 1 - SS_res / SS_tot.
residuals = flat_predicted_q - flat_true_q
ss_res = np.sum(residuals ** 2)
ss_tot = np.sum((flat_true_q - np.mean(flat_true_q)) ** 2)
r_squared = 1 - (ss_res / ss_tot)
print("R^2:", r_squared)


# Absolute channel-0 errors (scaled by 4.2) of the predicted curve and of Calculated_test_ocv
# against the true curve; the first curve point is skipped in both.
error_matrix2 = abs(4.2 * predicted_ocv_curve_test[:, 1:, 0] - 4.2 * true_ocv_curve_test[:, 1:, 0])
error_matrix3 = abs(4.2 * Calculated_test_ocv[:, 1:] - 4.2 * true_ocv_curve_test[:, 1:, 0])

# Common colour limits so that both heatmaps share one scale.
vmin = min(np.min(error_matrix2), np.min(error_matrix3))
vmax = max(np.max(error_matrix2), np.max(error_matrix3))

plt.figure(num=None, figsize=(7 / 2.54, 5 / 2.54), dpi=600)
plt.ion()
plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
ax = sns.heatmap(error_matrix2, cmap=cmap, xticklabels=8, vmin=vmin, vmax=vmax)

# Ten evenly spaced x ticks. Despite its name, `voltage_values` holds the first sample's
# channel-1 values scaled by nominal_capacity, matching the 'Q [Ah]' axis label.
voltage_values = true_ocv_curve_test[0, 1:, 1] * nominal_capacity
x_ticks = np.linspace(0, len(voltage_values) - 1, 10).astype(int)
ax.set_xticks(x_ticks)
ax.set_xticklabels([f'{q_val:.2f}' for q_val in voltage_values[x_ticks]])
plt.xlabel('Q [Ah]')
# No y tick labels or tick marks.
ax.set_yticklabels([])
ax.set_yticks([])
plt.show()

# Metrics of the predicted curve's channel 0 against the true one (first point skipped).
flat_predicted_v = predicted_ocv_curve_test[:, 1:, 0].reshape(-1) * 4.2
flat_true_v = true_ocv_curve_test[:, 1:, 0].reshape(-1) * 4.2
RMSE_predicted, MAE_predicted = error_evaluation(flat_predicted_v, flat_true_v)

# Coefficient of determination: R^2 = 1 - SS_res / SS_tot.
flat_true_v_arr = np.array(flat_true_v)
residuals = np.array(flat_predicted_v) - flat_true_v_arr
ss_res = np.sum(residuals ** 2)
ss_tot = np.sum((flat_true_v_arr - np.mean(flat_true_v_arr)) ** 2)
r_squared = 1 - (ss_res / ss_tot)
print("R^2:", r_squared)


# Heatmap of error_matrix3 (Calculated_test_ocv vs. the true curve), drawn with the same
# colour limits vmin/vmax as the error_matrix2 heatmap.
plt.figure(num=None, figsize=(7 / 2.54, 5 / 2.54), dpi=600)
plt.ion()
plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
ax = sns.heatmap(error_matrix3, cmap=cmap, xticklabels=8, vmin=vmin, vmax=vmax)

# Ten evenly spaced x ticks. Despite its name, `voltage_values` holds the first sample's
# channel-1 values scaled by nominal_capacity, matching the 'Q [Ah]' axis label.
voltage_values = true_ocv_curve_test[0, 1:, 1] * nominal_capacity
x_ticks = np.linspace(0, len(voltage_values) - 1, 10).astype(int)
ax.set_xticks(x_ticks)
ax.set_xticklabels([f'{q_val:.2f}' for q_val in voltage_values[x_ticks]])
plt.xlabel('Q [Ah]')
# No y tick labels or tick marks.
ax.set_yticklabels([])
ax.set_yticks([])
plt.show()

# Metrics of Calculated_test_ocv against the true channel 0 (first point skipped).
flat_derived_v = np.array(Calculated_test_ocv[:, 1:].reshape(-1) * 4.2)
flat_true_v = true_ocv_curve_test[:, 1:, 0].reshape(-1) * 4.2
RMSE_predicted, MAE_predicted = error_evaluation(flat_derived_v, flat_true_v)

# Coefficient of determination: R^2 = 1 - SS_res / SS_tot.
flat_true_v_arr = np.array(flat_true_v)
residuals = flat_derived_v - flat_true_v_arr
ss_res = np.sum(residuals ** 2)
ss_tot = np.sum((flat_true_v_arr - np.mean(flat_true_v_arr)) ** 2)
r_squared = 1 - (ss_res / ss_tot)
print("R^2:", r_squared)


#%% SHAP analysis for encoder in health diagnosis

# Test inputs on the compute device; below they serve both as the SHAP background data and
# as the samples being explained.
x = test_inputs_tensor.to(device)

class CpModel(torch.nn.Module):
    """Adapter that exposes only the second output (``Cp``) of the wrapped model.

    The wrapped model must return exactly eight values; this gives SHAP a
    single-output module to explain.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        outputs = self.model(x)
        # The 8-way unpacking also checks that the wrapped model yields eight outputs.
        _, cp_out, _, _, _, _, _, _ = outputs
        return cp_out

class CnModel(torch.nn.Module):
    """Adapter that exposes only the third output (``Cn``) of the wrapped model.

    The wrapped model must return exactly eight values; this gives SHAP a
    single-output module to explain.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        outputs = self.model(x)
        # The 8-way unpacking also checks that the wrapped model yields eight outputs.
        _, _, cn_out, _, _, _, _, _ = outputs
        return cn_out

class X0Model(torch.nn.Module):
    """Adapter that exposes only the fourth output (``x0``) of the wrapped model.

    The wrapped model must return exactly eight values; this gives SHAP a
    single-output module to explain.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        outputs = self.model(x)
        # The 8-way unpacking also checks that the wrapped model yields eight outputs.
        _, _, _, x0_out, _, _, _, _ = outputs
        return x0_out

class Y0Model(torch.nn.Module):
    """Adapter that exposes only the fifth output (``y0``) of the wrapped model.

    The wrapped model must return exactly eight values; this gives SHAP a
    single-output module to explain.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        outputs = self.model(x)
        # The 8-way unpacking also checks that the wrapped model yields eight outputs.
        _, _, _, _, y0_out, _, _, _ = outputs
        return y0_out

def _drop_output_axis(shap_values):
    """Remove a trailing length-1 output axis and leave every other axis untouched.

    A bare ``squeeze()`` would also collapse the sample or time axis whenever it happens
    to have length 1, which would break the ``axis=1`` averaging done afterwards.
    """
    if shap_values.shape[-1] == 1:
        return np.squeeze(shap_values, axis=-1)
    return shap_values


# Single-output adapters around the diagnostic model, one per explained parameter.
cp_model = CpModel(model).to(device)
cn_model = CnModel(model).to(device)
x0_model = X0Model(model).to(device)
y0_model = Y0Model(model).to(device)

# The test inputs `x` are used as the background data set of every explainer.
cp_explainer = shap.DeepExplainer(cp_model, x)
cn_explainer = shap.DeepExplainer(cn_model, x)
x0_explainer = shap.DeepExplainer(x0_model, x)
y0_explainer = shap.DeepExplainer(y0_model, x)

# SHAP values for the same samples; the additivity check is disabled.
cp_shap_values = cp_explainer.shap_values(x, check_additivity=False)
cn_shap_values = cn_explainer.shap_values(x, check_additivity=False)
x0_shap_values = x0_explainer.shap_values(x, check_additivity=False)
y0_shap_values = y0_explainer.shap_values(x, check_additivity=False)

cp_shap_values_squeezed = _drop_output_axis(cp_shap_values)
cn_shap_values_squeezed = _drop_output_axis(cn_shap_values)
x0_shap_values_squeezed = _drop_output_axis(x0_shap_values)
y0_shap_values_squeezed = _drop_output_axis(y0_shap_values)

# Average over axis 1 so that one value per sample and input feature remains.
cp_shap_values_aggregated = np.mean(cp_shap_values_squeezed, axis=1)
cn_shap_values_aggregated = np.mean(cn_shap_values_squeezed, axis=1)
x0_shap_values_aggregated = np.mean(x0_shap_values_squeezed, axis=1)
y0_shap_values_aggregated = np.mean(y0_shap_values_squeezed, axis=1)

# Inputs averaged the same way; used as the feature values in the summary plots.
x_aggregated = np.mean(x.cpu().numpy(), axis=1)

# Two-colour map and feature labels shared by the four summary plots.
cmap2 = mcolors.LinearSegmentedColormap.from_list('custom_cmap', ['#8da0cb', '#66c2a5'])
feature_names = ["V", "Q", "EFC"]

# One SHAP summary plot per explained parameter, in the order Cp, Cn, x0, y0.
for aggregated_shap in (
    cp_shap_values_aggregated,
    cn_shap_values_aggregated,
    x0_shap_values_aggregated,
    y0_shap_values_aggregated,
):
    plt.figure(num=None, figsize=(10 / 2.54, 5 / 2.54), dpi=600)
    plt.ion()
    shap.summary_plot(
        aggregated_shap,
        features=x_aggregated,
        feature_names=feature_names,
        max_display=4,
        plot_size=(10 / 2.54, 5 / 2.54),
        cmap=cmap2,
    )
    plt.show()


# Stack the aggregated SHAP arrays; axis 0 follows the order Cp, Cn, x0, y0.
shap_values_list = [
    cp_shap_values_aggregated,
    cn_shap_values_aggregated,
    x0_shap_values_aggregated,
    y0_shap_values_aggregated,
]
shap_values_stacked = np.stack(shap_values_list, axis=0)

colors = ['#fb8d62', '#8da0cb', '#66c2a5']
feature_names = ["V", "Q", "EFC"]
parameter_names = ["Cp", "Cn", "x0", "y0"]

fig, ax = plt.subplots(num=None, figsize=(10 / 2.54, 6 / 2.54), dpi=600)
plt.ion()
plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)

# Outlier marker style shared by all boxes (copied per call).
flier_style = dict(marker='o', color='grey', markerfacecolor='grey', markeredgecolor='grey',
                   linewidth=0, markersize=2, alpha=0.6)

# One group of three boxes per parameter (one box per input feature); groups are four
# positions apart, which leaves a one-position gap between neighbouring groups.
for i in range(len(parameter_names)):
    for j, feature_colour in enumerate(colors):
        plt.boxplot(
            shap_values_stacked[i, :, j],
            positions=[i * 4 + j],
            widths=0.6,
            patch_artist=True,
            boxprops=dict(facecolor=feature_colour, color=feature_colour),
            medianprops=dict(color='grey'),
            whiskerprops=dict(color=feature_colour),
            capprops=dict(color=feature_colour),
            flierprops=dict(flier_style),
        )

# A single tick per group, centred on its middle box.
plt.xticks([4 * group + 1 for group in range(len(parameter_names))], parameter_names)
plt.xlabel("Parameters")
plt.ylabel("SHAP Value")
custom_legend = [plt.Line2D([0], [0], color=feature_colour, lw=6) for feature_colour in colors]
plt.legend(custom_legend, feature_names, frameon=False, labelspacing=0.01, fontsize=12,
           handlelength=1)
plt.tight_layout()
plt.show()


# Mean and standard deviation of |SHAP| over the samples (axis 1), per parameter and feature.
shap_values_stacked = np.abs(shap_values_stacked)
means = np.mean(shap_values_stacked, axis=1)
stds = np.std(shap_values_stacked, axis=1)

fig, ax = plt.subplots(num=None, figsize=(10/2.54, 5.5/2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)

bar_width = 0.2
# Left edge position of every parameter group. A dedicated name is used on purpose: the
# script-level `x` holds the test-input tensor used by the SHAP code and must not be
# overwritten by plotting coordinates.
bar_positions = np.arange(len(parameter_names))

# One bar per input feature inside each parameter group, with +/- one standard deviation.
for i in range(3):
    ax.bar(bar_positions + i * bar_width, means[:, i], bar_width, color=colors[i],
           label=feature_names[i])
    ax.errorbar(bar_positions + i * bar_width, means[:, i], yerr=stds[:, i], fmt='none',
                color='grey', capsize=3, elinewidth=0.8)

# Centre the tick of each group on its middle bar.
ax.set_xticks(bar_positions + bar_width)
ax.set_xticklabels(parameter_names)
ax.set_xlabel("Parameters")
ax.set_ylabel("Absolute SHAP Value")
plt.legend(frameon=False, labelspacing=0.05, fontsize=12)
plt.tight_layout()
plt.show()


#%% SHAP for decoder in health diagnosis
# Test inputs on the compute device.
x = test_inputs_tensor.to(device)
# The four encoder outputs; they are concatenated below as the decoder inputs for SHAP and
# read again, per sample, by the fitting section further down.
Cp, Cn, x0, y0= model.encoder(x)

# 
class DecoderModel(torch.nn.Module):
    """Adapter that maps concatenated ``(Cp, Cn, x0, y0)`` to two scalar summaries.

    The wrapped decoder is called with the four parameters as separate width-1
    tensors; its output is reduced to the mean of channel 0 and the mean of
    channel 1 over axis 1, returned side by side with shape ``(batch, 2)``.
    """

    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder

    def forward(self, inputs):
        # Undo the concatenation: four slices of width 1 along the last axis.
        cp_part, cn_part, x0_part, y0_part = inputs.split([1, 1, 1, 1], dim=-1)
        curve = self.decoder(cp_part, cn_part, x0_part, y0_part)
        # Collapse each channel of the curve to its mean over the curve points.
        v_mean = curve[:, :, 0].mean(dim=1)
        q_mean = curve[:, :, 1].mean(dim=1)
        return torch.stack((v_mean, q_mean), dim=-1)


decoder_model = DecoderModel(model.decoder).to(device)

# Decoder inputs: the four encoder outputs concatenated along the last axis.
decoder_inputs = torch.cat((Cp, Cn, x0, y0), dim=-1)

# The same tensor is used as SHAP background data and as the samples to explain.
decoder_explainer = shap.DeepExplainer(decoder_model, decoder_inputs)
decoder_shap_values = decoder_explainer.shap_values(decoder_inputs, check_additivity=False)

# Depending on the installed SHAP version, a two-output model yields either a list with one
# array per output or one array with the outputs on the last axis. Plain `[0]` / `[1]`
# indexing is only correct for the list form (on an array it would pick samples), so both
# layouts are handled explicitly.
if isinstance(decoder_shap_values, (list, tuple)):
    V_OCV_curve_shap_values = decoder_shap_values[0]
    Q_OCV_curve_shap_values = decoder_shap_values[1]
else:
    decoder_shap_values = np.asarray(decoder_shap_values)
    V_OCV_curve_shap_values = decoder_shap_values[..., 0]
    Q_OCV_curve_shap_values = decoder_shap_values[..., 1]

feature_names = ["Cp", "Cn", "x0", "y0"]

cmap2 = mcolors.LinearSegmentedColormap.from_list('custom_cmap',[ '#8da0cb', '#66c2a5'])
plt.figure(num=None,figsize=(10/2.54,6/2.54),dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
# Summary plot for the second decoder output (mean of curve channel 1).
shap.summary_plot(Q_OCV_curve_shap_values,
                  features=decoder_inputs.cpu().detach().numpy(),
                  feature_names=feature_names,
                  max_display=4,plot_size=(10/2.54, 5/2.54),
                  cmap=cmap2)
plt.show()

#%% Show the fitting results of health diagnosis

# PSO optimization function
def pso_objective_function(params):
    """Mean squared error between ``predict_V`` and an OCV curve rebuilt from ``params``.

    Parameters
    ----------
    params : sequence of four floats
        ``(Cp, Cn, x0, y0)`` as proposed by the optimiser.

    Returns
    -------
    float
        ``mean((predict_V - (OCP_p(y0 - Q / Cp) - OCP_n(x0 - Q / Cn))) ** 2)`` taken
        over the points ``Q`` of ``predict_Q``.

    Notes
    -----
    Reads the module-level ``predict_Q``, ``predict_V``, ``OCP_p``, ``OCP_n`` and
    ``device``; ``predict_Q`` and ``predict_V`` must be set before the optimiser runs.
    """
    Cp, Cn, x0, y0 = params
    fitted_Voc = np.zeros_like(predict_Q)
    for j, Qt in enumerate(predict_Q):
        # Electrode operating points at this capacity value.
        SOC_p = y0 - Qt / Cp
        SOC_n = x0 - Qt / Cn

        # Tensors live on the script's compute device instead of a hard-coded 'cuda'.
        SOC_p_tensor = torch.tensor([SOC_p], dtype=torch.float32).to(device)
        SOC_n_tensor = torch.tensor([SOC_n], dtype=torch.float32).to(device)

        # Reconstructed cell voltage: positive minus negative electrode potential.
        fitted_Voc[j] = OCP_p(SOC_p_tensor).item() - OCP_n(SOC_n_tensor).item()

    return np.mean((predict_V - fitted_Voc) ** 2)  # MSE


# Line format and colour of every curve label used in the comparison figures below.
_CURVE_STYLES = {
    'Measured OCV': ('-', 'grey'),
    'Predicted OCV': ('--', '#84c3b7'),
    'Derived OCV': ('-.', '#7da6c6'),
    'Derived OCPp': ('-.', '#eaaa60'),
    'Derived OCPn': ('-.', '#b7b2d0'),
    'Fitted OCV': (':', '#e68b81'),
    'Fitted OCPp': (':', '#78A040'),
    'Fitted OCPn': (':', '#ED98CC'),
}
# Label groups, listed in drawing order.
_ALL_CURVES = ('Measured OCV', 'Predicted OCV', 'Derived OCV', 'Derived OCPp',
               'Derived OCPn', 'Fitted OCV', 'Fitted OCPp', 'Fitted OCPn')
_FULL_CELL_CURVES = ('Measured OCV', 'Predicted OCV', 'Derived OCV', 'Fitted OCV')
_CATHODE_CURVES = ('Derived OCPp', 'Fitted OCPp')
_ANODE_CURVES = ('Derived OCPn', 'Fitted OCPn')
_INSET_CURVES = ('Measured OCV', 'Predicted OCV', 'Derived OCV', 'Derived OCPp',
                 'Fitted OCV', 'Fitted OCPp')


def _apply_tick_style():
    """Apply the script's tick styling to the current figure (inward ticks, no tick marks)."""
    plt.ion()
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    plt.tick_params(top='on', right='on', which='both')
    plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)


def _draw_curves(target, curves, labels):
    """Plot ``curves[label] = (x, y)`` for each label on ``target`` (pyplot or an Axes)."""
    for label in labels:
        q_axis, values = curves[label]
        fmt, colour = _CURVE_STYLES[label]
        target.plot(q_axis, values, fmt, label=label, color=colour, linewidth=2)


def _curve_figure(curves, labels, ylabel, ylim, xlim=None, legend=False):
    """Draw one 8 cm x 5 cm figure with the selected curves over 'Q [Ah]' and show it."""
    plt.figure(num=None, figsize=(8 / 2.54, 5 / 2.54), dpi=600)
    _apply_tick_style()
    _draw_curves(plt, curves, labels)
    plt.xlabel('Q [Ah]')
    plt.ylabel(ylabel)
    if legend:
        plt.legend(frameon=False, labelspacing=0.1, loc='center left', bbox_to_anchor=(1, 0.5))
    if xlim is not None:
        plt.xlim(xlim)
    plt.ylim(ylim)
    plt.show()


def _electrode_curves(Cp_val, Cn_val, x0_val, y0_val, q_points):
    """Evaluate OCP_p(y0 - Q/Cp), OCP_n(x0 - Q/Cn) and their difference for each Q.

    Returns three arrays shaped like ``q_points``: the OCP_p values, the OCP_n values and
    OCP_p - OCP_n.
    """
    ocp_p_curve = np.zeros_like(q_points)
    ocp_n_curve = np.zeros_like(q_points)
    ocv_curve = np.zeros_like(q_points)
    for j, Qt in enumerate(q_points):
        # Tensors live on the script's compute device instead of a hard-coded 'cuda'.
        soc_p = torch.tensor([y0_val - Qt / Cp_val], dtype=torch.float32).to(device)
        soc_n = torch.tensor([x0_val - Qt / Cn_val], dtype=torch.float32).to(device)
        up = OCP_p(soc_p).item()
        un = OCP_n(soc_n).item()
        ocp_p_curve[j] = up
        ocp_n_curve[j] = un
        ocv_curve[j] = up - un
    return ocp_p_curve, ocp_n_curve, ocv_curve


# Candidate test samples picked from error_matrix: lowest mean row error, largest single
# error, and the row whose mean error is closest to the median of the row means.
mean_row_error = np.mean(error_matrix, axis=1)
min_row_index = np.argmin(mean_row_error)
max_row_index = np.argmax(np.max(error_matrix, axis=1))
median_row_index = np.argmin(np.abs(mean_row_error - np.median(mean_row_error)))

# Only the lowest-mean-error sample is processed; put other indices into the tuple to
# analyse further samples.
for i in (min_row_index,):

    # Measured V-Q input together with the true, predicted and derived OCV curves.
    plt.figure(num=None, figsize=(7/2.54, 6/2.54), dpi=600)
    _apply_tick_style()
    plt.plot(test_inputs_tensor[i, :, 1].cpu().numpy() * nominal_capacity,
             test_inputs_tensor[i, :, 0].cpu().numpy() * 4.2,
             label='Measure V-Q', color='grey', linestyle='-')
    plt.plot(true_ocv_curve_test[i, :, 1]*nominal_capacity, true_ocv_curve_test[i, :, 0]*4.2,
             label='True OCV', color='blue', linestyle='-')
    plt.plot(predicted_ocv_curve_test[i, :, 1]*nominal_capacity, predicted_ocv_curve_test[i, :, 0]*4.2,
             label='Predicted OCV', color='red', linestyle='--')
    plt.plot(predicted_ocv_curve_test[i, :, 1]*nominal_capacity, Calculated_test_ocv[i, :]*4.2,
             label='Derived OCV', color='orange', linestyle='-.')
    plt.xlabel('Q [Ah]')
    plt.ylabel('V [V]')
    plt.legend(frameon=False, labelspacing=0.01)
    plt.show()

    # Targets of the PSO fit; pso_objective_function reads these two script-level names.
    predict_Q = predicted_ocv_curve_test[i, :, 1]
    predict_V = predicted_ocv_curve_test[i, :, 0] * 4.2

    # Encoder outputs for this sample.
    Cp_con = Cp[i].item()
    Cn_con = Cn[i].item()
    x0_con = x0[i].item()
    y0_con = y0[i].item()

    lb = [0.1, 0.1, 1e-3, 1e-3]  # Lower bounds for (Cp, Cn, x0, y0)
    ub = [1.1, 1.1, 1.0, 1.0]  # Upper bounds for (Cp, Cn, x0, y0)
    # Cap the number of PSO iterations to limit the run time.
    options = {'maxiter': 20}
    optimized_params, fopt = pso(pso_objective_function, lb, ub, minstep=1e-5, minfunc=1e-5, **options)

    print(f"Constrianed parameters: Cp={Cp_con}, Cn={Cn_con}, x0={x0_con}, y0={y0_con}")
    print(f"PSO optimized parameters: Cp={optimized_params[0]}, Cn={optimized_params[1]}, x0={optimized_params[2]}, y0={optimized_params[3]}")

    # Electrode and cell curves rebuilt from the PSO result ("fit") and from the encoder
    # outputs ("con").
    Cp_opt, Cn_opt, x0_opt, y0_opt = optimized_params
    OCPp_curve_fit, OCPn_curve_fit, fitted_Voc = _electrode_curves(
        Cp_opt, Cn_opt, x0_opt, y0_opt, predict_Q)
    OCPp_curve_con, OCPn_curve_con, Con_Voc = _electrode_curves(
        Cp_con, Cn_con, x0_con, y0_con, predict_Q)

    # Capacity axes in Ah.
    # NOTE(review): the first figure scales Q by `nominal_capacity` while everything below
    # uses the literal 4.84 — confirm both are the same value.
    q_true_ah = true_ocv_curve_test[i, :, 1] * 4.84
    q_pred_ah = predict_Q * 4.84

    voltage_curves = {
        'Measured OCV': (q_true_ah, true_ocv_curve_test[i, :, 0] * 4.2),
        'Predicted OCV': (q_pred_ah, predict_V),
        'Derived OCV': (q_pred_ah, Calculated_test_ocv[i, :] * 4.2),
        'Derived OCPp': (q_pred_ah, OCPp_curve_con),
        'Derived OCPn': (q_pred_ah, OCPn_curve_con),
        'Fitted OCV': (q_pred_ah, fitted_Voc),
        'Fitted OCPp': (q_pred_ah, OCPp_curve_fit),
        'Fitted OCPn': (q_pred_ah, OCPn_curve_fit),
    }

    # Overview of all voltage curves with a zoomed inset.
    fig, ax = plt.subplots(figsize=(8/2.54, 5/2.54), dpi=600)
    _apply_tick_style()
    _draw_curves(plt, voltage_curves, _ALL_CURVES)
    plt.xlabel('Q [Ah]')
    plt.ylabel('Voltage [V]')
    plt.legend(frameon=False, labelspacing=0.1, loc='center left', bbox_to_anchor=(1, 0.5))
    ax_inset = inset_axes(ax, width="50%", height="40%",
                          bbox_to_anchor=(-0.3, -0.35, 1, 1.1),
                          bbox_transform=ax.transAxes)
    ax_inset.tick_params(bottom=False, left=False)
    # All inset curves are drawn on `ax_inset` explicitly rather than through pyplot's
    # current-axes state.
    _draw_curves(ax_inset, voltage_curves, _INSET_CURVES)
    ax_inset.tick_params(axis='x', labelsize=10.5)
    ax_inset.tick_params(axis='y', labelsize=10.5)
    ax_inset.set_ylim(3.7, 4)
    ax_inset.set_xlim(1, 2.4)
    plt.show()

    # Differential voltage (dV/dQ) of every curve with respect to capacity in Ah.
    dv_dq_orig = gradient(true_ocv_curve_test[i, :, 0] * 4.2, q_true_ah)
    dv_dq_predict = gradient(predict_V, q_pred_ah)
    dv_dq_fit = gradient(fitted_Voc, q_pred_ah)
    dv_dq_calculate = gradient(Calculated_test_ocv[i, :] * 4.2, q_pred_ah)
    dv_dq_cathode_fit = gradient(OCPp_curve_fit, q_pred_ah)
    dv_dq_anode_fit = gradient(OCPn_curve_fit, q_pred_ah)
    dv_dq_cathode_con = gradient(OCPp_curve_con, q_pred_ah)
    dv_dq_anode_con = gradient(OCPn_curve_con, q_pred_ah)

    dvdq_curves = {
        'Measured OCV': (q_true_ah, dv_dq_orig),
        'Predicted OCV': (q_pred_ah, dv_dq_predict),
        'Derived OCV': (q_pred_ah, dv_dq_calculate),
        'Derived OCPp': (q_pred_ah, dv_dq_cathode_con),
        'Derived OCPn': (q_pred_ah, dv_dq_anode_con),
        'Fitted OCV': (q_pred_ah, dv_dq_fit),
        'Fitted OCPp': (q_pred_ah, dv_dq_cathode_fit),
        'Fitted OCPn': (q_pred_ah, dv_dq_anode_fit),
    }

    # dV/dQ overview of all curves.
    _curve_figure(dvdq_curves, _ALL_CURVES, 'dV/dQ [V/Ah]', [-1.5, 1.1])

    # Zoomed voltage views: cell OCV, OCPp curves, OCPn curves.
    _curve_figure(voltage_curves, _FULL_CELL_CURVES, 'Voltage [V]', [3.5, 4.1],
                  xlim=[1, 3], legend=True)
    _curve_figure(voltage_curves, _CATHODE_CURVES, 'Voltage [V]', [3.7, 4.2],
                  xlim=[1, 3], legend=True)
    _curve_figure(voltage_curves, _ANODE_CURVES, 'Voltage [V]', [0.15, 0.23],
                  xlim=[1, 3], legend=True)

    # Zoomed dV/dQ views in the same three groups (drawn without legends).
    _curve_figure(dvdq_curves, _FULL_CELL_CURVES, 'dV/dQ [V/Ah]', [-0.3, -0.05], xlim=[1, 3])
    _curve_figure(dvdq_curves, _CATHODE_CURVES, 'dV/dQ [V/Ah]', [-0.3, -0.05], xlim=[1, 3])
    _curve_figure(dvdq_curves, _ANODE_CURVES, 'dV/dQ [V/Ah]', [0.00, 0.1], xlim=[1, 3])
    
#%
#%% extract data for health prognosis

def _resample_uniform(values, num_points):
    """Linearly resample a 1-D sequence onto ``num_points`` evenly spaced positions.

    The samples are treated as evenly spaced on [0, 1], whatever their
    original spacing was.
    """
    values = np.asarray(values)
    source_grid = np.linspace(0, 1, len(values))
    return interp1d(source_grid, values, kind='linear')(np.linspace(0, 1, num_points))


def prognostic_data_prepare(diagnostic_model, data, idx_seed, rounds_sample, num_points=200):
    """Build health-prognosis samples for one cell.

    For each usable check-up a random window of the working charge is fed to
    ``diagnostic_model``; its outputs become the prognosis inputs, and the
    remaining capacity-fade trajectory of the cell becomes the target.

    Args:
        diagnostic_model: Callable taking a ``(1, num_points, 3)`` float32
            tensor (voltage, capacity, EFC) and returning eight values, of
            which the first five (curve, Cp, Cn, x0, y0) are used.
        data: Cell record. ``data['rpt']`` is a list of check-up dicts which
            may contain ``'02C'`` (``discharge_capacity``, ``voltage``,
            ``EFC``) and ``'work'`` (``current``, ``voltage``,
            ``charge_capacity``, ``EFC``).
        idx_seed: Base random seed; round ``r`` is seeded with ``idx_seed + r``.
        rounds_sample: Number of random charging windows drawn per check-up.
        num_points: Number of points each charging signal is resampled to.

    Returns:
        ``(inputs, inputs_add, outputs)`` as NumPy arrays with matching first
        dimension. ``inputs`` holds the diagnosed curves reshaped to
        ``(-1, 2)``; ``inputs_add`` holds
        ``[Cp, Cn, x0, y0, last diagnosed value of channel 1, last EFC]``
        (six values per sample, assuming the model returns ``(1, 1)`` arrays
        for Cp, Cn, x0, y0 -- TODO confirm); ``outputs`` holds 25-point
        ``[capacity, EFC]`` trajectories. All three are empty arrays when no
        sample could be built.

    Notes:
        Capacities are normalised by 4.84, voltages by 4.2 and EFCs by 500.
        Only the check-up at index 2 is sampled, and only if it is not the
        last check-up of the cell.
    """
    nominal_capacity = 4.84

    # ---- capacity-fade trajectory from the 0.2C check-ups -----------------
    discharge_capacity_curve = []
    EFCs_curve = []
    for rpt_data in data['rpt']:
        if '02C' not in rpt_data:
            continue
        discharge_capacity_02C = np.array(rpt_data['02C']['discharge_capacity'])
        EFCs_02C = np.array(rpt_data['02C']['EFC'])
        if discharge_capacity_02C.size == 0:
            continue
        discharge_capacity_curve.append(max(discharge_capacity_02C))
        EFCs_curve.append(max(EFCs_02C))

    # A trajectory needs at least two points to be interpolated; without it
    # there is nothing to predict, so report "no samples" instead of crashing.
    if len(discharge_capacity_curve) < 2:
        return np.array([]), np.array([]), np.array([])

    discharge_capacity_curve = np.array(discharge_capacity_curve) / nominal_capacity
    EFCs_curve = np.array(EFCs_curve) / 500
    discharge_capacity_curve_interp = _resample_uniform(discharge_capacity_curve, 10000)
    EFCs_curve_interp = _resample_uniform(EFCs_curve, 10000)

    # Predict down to 80 % capacity, or to the last measured capacity if the
    # cell never got that low.
    predict_end_capacity = max(discharge_capacity_curve[-1], 0.8)

    inputs = []
    inputs_add = []
    outputs = []

    rpt_list = data['rpt']
    final_dia = min(len(rpt_list), 3)
    for rpt_data in rpt_list[2:final_dia]:
        for round_idx in range(rounds_sample):
            # Re-seed every round (also when the round is skipped) so the
            # global RNG state evolves exactly as callers expect.
            set_random_seed(idx_seed + round_idx)

            if 'work' not in rpt_data or '02C' not in rpt_data:
                continue
            if rpt_data is rpt_list[-1]:
                # The last check-up has no future trajectory to predict.
                continue

            work = rpt_data['work']
            charge_current = np.array(work['current'])
            charge_voltage = np.array(work['voltage'])
            charge_capacity = np.array(work['charge_capacity'])
            EFCs = np.array(work['EFC'])
            if charge_current.size < 50:
                continue

            # End of the usable charge: first sample where the voltage drops
            # (falls through to the last index if it never drops).
            for charge_end_idx in range(2, len(charge_current)):
                if charge_voltage[charge_end_idx] < charge_voltage[charge_end_idx - 1]:
                    break

            # Random partial-charge window: start in the first 30 %, end in
            # the last 30 % of the usable charge.
            start_index = random.randint(0, int(0.3 * charge_end_idx))
            end_index = random.randint(int(0.7 * charge_end_idx), charge_end_idx)
            window = slice(start_index, end_index)

            if charge_current[window].size < 50:
                continue

            capacity_offset = charge_capacity[start_index]
            charge_voltage = charge_voltage[window] / 4.2
            # Capacity is expressed relative to the start of the window.
            charge_capacity = charge_capacity[window] / nominal_capacity - capacity_offset / nominal_capacity
            EFCs = EFCs[window] / 500

            charge_voltage_interp = _resample_uniform(charge_voltage, num_points)
            charge_capacity_interp = _resample_uniform(charge_capacity, num_points)
            charge_EFCs_interp = _resample_uniform(EFCs, num_points)

            # 0.2C discharge of the same check-up, starting at the last
            # sample whose discharged capacity is still non-positive.
            discharge_capacity = np.array(rpt_data['02C']['discharge_capacity'])
            non_positive = np.where(discharge_capacity <= 0)[0]
            if len(non_positive) > 0:
                dis_start_idx = non_positive[-1]
            else:
                # NOTE(review): this yields an empty slice, so the sample is
                # skipped below; confirm whether index 0 was intended.
                dis_start_idx = len(discharge_capacity)
            discharge_capacity = discharge_capacity[dis_start_idx:] / nominal_capacity
            if discharge_capacity.size < 2:
                continue

            discharge_capacity_interp = _resample_uniform(discharge_capacity, num_points)
            discharge_capacity_diff = np.diff(discharge_capacity_interp)
            # Quality gates: no large jumps in the discharge curve, enough
            # charge throughput in the window, and a minimum discharge capacity.
            if (max(discharge_capacity_diff) > 0.05
                    or max(charge_capacity) * nominal_capacity < 0.2
                    or max(discharge_capacity) < 0.835):
                continue

            # Target: the part of the fade trajectory between the present
            # capacity and the prediction end point.  This is resolved BEFORE
            # anything is appended so inputs and outputs stay aligned.
            current_capacity = discharge_capacity[-1]
            mask = ((discharge_capacity_curve_interp <= current_capacity)
                    & (discharge_capacity_curve_interp >= predict_end_capacity))
            capacity_to_predict = discharge_capacity_curve_interp[mask]
            EFCs_to_predict = EFCs_curve_interp[mask]
            if len(capacity_to_predict) < 2:
                continue
            capacity_to_predict_interp = _resample_uniform(capacity_to_predict, 25)
            EFCs_to_predict_interp = _resample_uniform(EFCs_to_predict, 25)

            # Diagnostic model input: (1, num_points, 3) = voltage, capacity, EFC
            input_data = np.stack([charge_voltage_interp, charge_capacity_interp, charge_EFCs_interp], axis=-1)
            input_data_tensor = torch.tensor(
                input_data.reshape(-1, input_data.shape[0], input_data.shape[1]),
                dtype=torch.float32,
            ).to(device)
            diagnosed_ocv, Cp, Cn, x0, y0, _, _, _ = diagnostic_model(input_data_tensor)
            diagnosed_ocv = diagnosed_ocv.cpu().detach().numpy()
            Cp, Cn, x0, y0 = (t.cpu().detach().numpy() for t in (Cp, Cn, x0, y0))

            inputs.append(diagnosed_ocv.reshape(-1, 2))
            inputs_add.append(np.concatenate(
                [Cp, Cn, x0, y0,
                 diagnosed_ocv[:, -1, 1].reshape(1, 1),
                 np.array(charge_EFCs_interp[-1]).reshape(1, 1)],
                axis=1,
            ).reshape(-1))
            outputs.append(np.stack([capacity_to_predict_interp, EFCs_to_predict_interp], axis=-1))

    # Convert the sample lists to NumPy arrays (empty arrays when no sample qualified)
    inputs = np.array(inputs)
    inputs_add = np.array(inputs_add)
    outputs = np.array(outputs)

    return inputs, inputs_add, outputs
    
    
# Construct training and test datasets for the prognosis decoder.
# Each list collects one array per cell; cells that yield no sample are skipped.
train_inputs_prog, train_inputs_add_prog, train_outputs_prog = [], [], []
test_inputs_prog, test_inputs_add_prog, test_outputs_prog = [], [], []


# Process the battery data in the training set (three random charging windows per cell)
for i, battery in enumerate(train_batteries):
    print("Train cell", battery)
    input_data, input_data_add, output_data = prognostic_data_prepare(
        model, data_all_cells[battery], i, rounds_sample=3)
    if input_data.size == 0:
        continue
    train_inputs_prog.append(input_data)
    train_inputs_add_prog.append(input_data_add)
    train_outputs_prog.append(output_data)


# Process the battery data in the test set (a single charging window per cell)
for i, battery in enumerate(test_batteries):
    print("Test cell", battery)
    input_data, input_data_add, output_data = prognostic_data_prepare(
        model, data_all_cells[battery], i, rounds_sample=1)
    if input_data.size == 0:
        continue
    test_inputs_prog.append(input_data)
    test_inputs_add_prog.append(input_data_add)
    test_outputs_prog.append(output_data)


# Merge the per-cell arrays along the sample axis
train_inputs_prog = np.concatenate(train_inputs_prog, axis=0)
train_inputs_add_prog = np.concatenate(train_inputs_add_prog, axis=0)
train_outputs_prog = np.concatenate(train_outputs_prog, axis=0)
test_inputs_prog = np.concatenate(test_inputs_prog, axis=0)
test_inputs_add_prog = np.concatenate(test_inputs_add_prog, axis=0)
test_outputs_prog = np.concatenate(test_outputs_prog, axis=0)

#%% prognosis decoder

class Decoder_Prog(nn.Module):
    """Fully connected decoder from six health features to a degradation forecast.

    The network has a shared trunk (``fc1`` followed by ``fc2``) and three
    linear heads: ``fcQ`` (capacity curve), ``fcC`` (EFC curve) and ``fcL``
    (a single life value).  Note that ``fc2`` stacks three linear layers with
    no activation between them.

    Args:
        num_points: Number of points in each predicted curve.
    """

    def __init__(self, num_points=25):
        super().__init__()
        self.num_points = num_points
        # Layers are created in a fixed order so seeded initialisation and
        # saved state-dict keys stay the same.
        self.fc1 = nn.Linear(6, 128)
        trunk_layers = [
            nn.Linear(128, 256),
            nn.Linear(256, 256),
            nn.Linear(256, 128),
        ]
        self.fc2 = nn.Sequential(*trunk_layers)
        self.fcC = nn.Linear(128, num_points)
        self.fcQ = nn.Linear(128, num_points)
        self.fcL = nn.Linear(128, 1)

    def forward(self, input_add):
        """Return ``(prognostic_curve, predict_life)``.

        ``prognostic_curve`` has shape ``(batch, num_points, 2)`` with channel
        0 from ``fcQ`` and channel 1 from ``fcC``; ``predict_life`` has shape
        ``(batch, 1)``.
        """
        hidden = torch.relu(self.fc1(input_add))
        hidden = torch.relu(self.fc2(hidden))
        efc_curve = self.fcC(hidden)
        capacity_curve = self.fcQ(hidden)
        life = self.fcL(hidden)
        return torch.stack((capacity_curve, efc_curve), dim=-1), life


def regression_loss_prog(predicted_curve, true_curve, predict_life):
    """Composite training loss for the prognosis decoder.

    Channel 0 of the curves is capacity and channel 1 is EFC.  The loss sums,
    in this order:

    * MSE of the capacity curve and of the EFC curve (weight 1.0 each),
    * MSE of the last EFC point and of the last capacity point (0.1 each),
    * MSE between ``predict_life`` and the last true EFC point (weight 1),
    * penalties on negative EFC and negative capacity predictions (0.1 each),
    * a monotonicity penalty (0.1) on increasing capacity / decreasing EFC.
    """
    mse = nn.functional.mse_loss

    pred_cap, pred_efc = predicted_curve[:, :, 0], predicted_curve[:, :, 1]
    true_cap, true_efc = true_curve[:, :, 0], true_curve[:, :, 1]

    # Point-to-point steps: capacity should not rise, EFC should not fall.
    cap_steps = pred_cap[:, 1:] - pred_cap[:, :-1]
    efc_steps = pred_efc[:, 1:] - pred_efc[:, :-1]
    monotonic_constraint = torch.mean(torch.relu(cap_steps)) + torch.mean(torch.relu(-efc_steps))

    terms = (
        1.0 * mse(pred_cap, true_cap),
        1.0 * mse(pred_efc, true_efc),
        0.1 * mse(pred_efc[:, -1], true_efc[:, -1]),
        0.1 * mse(pred_cap[:, -1], true_cap[:, -1]),
        1 * mse(predict_life.reshape(-1), true_efc[:, -1].reshape(-1)),
        0.1 * torch.mean(torch.relu(-pred_efc)),
        0.1 * torch.mean(torch.relu(0.0 - pred_cap)),
        0.1 * monotonic_constraint,
    )

    # Accumulate left to right so the floating-point result is reproducible.
    total = terms[0]
    for term in terms[1:]:
        total = total + term
    return total


def error_evaluation(prediction, real_value):
    """Print and return the RMSE and MAE between two value sequences.

    Both metrics are symmetric in their arguments, so the order in which
    prediction and ground truth are passed does not affect the result.

    Returns:
        Tuple ``(RMSE, MAE)``.
    """
    mean_sq_err = metrics.mean_squared_error(prediction, real_value)
    scores = (mean_sq_err ** 0.5, metrics.mean_absolute_error(prediction, real_value))
    print("RMSE, MAE", *scores)
    return scores


#%% data split for health prognosis
set_random_seed(123)


def _as_device_tensor(array):
    """Copy a NumPy array into a float32 tensor on the active device."""
    return torch.tensor(array, dtype=torch.float32).to(device)


train_inputs_prog_tensor = _as_device_tensor(train_inputs_prog)
train_inputs_add_prog_tensor = _as_device_tensor(train_inputs_add_prog)
train_outputs_prog_tensor = _as_device_tensor(train_outputs_prog)
test_inputs_prog_tensor = _as_device_tensor(test_inputs_prog)
test_inputs_add_prog_tensor = _as_device_tensor(test_inputs_add_prog)
test_outputs_prog_tensor = _as_device_tensor(test_outputs_prog)

# One shared permutation keeps inputs, additional inputs and targets aligned.
num_samples = train_inputs_prog_tensor.shape[0]
shuffled_indices = torch.randperm(num_samples)
(train_inputs_prog_shuffled,
 train_inputs_add_prog_shuffled,
 train_outputs_prog_shuffled) = (
    tensor[shuffled_indices]
    for tensor in (train_inputs_prog_tensor,
                   train_inputs_add_prog_tensor,
                   train_outputs_prog_tensor)
)

# Hold out the trailing 20 % of the shuffled samples for validation.
val_split = 0.2
num_train_samples = int((1 - val_split) * num_samples)

train_inputs_prog_train, train_inputs_prog_val = (
    train_inputs_prog_shuffled[:num_train_samples],
    train_inputs_prog_shuffled[num_train_samples:],
)
train_inputs_add_prog_train, train_inputs_add_prog_val = (
    train_inputs_add_prog_shuffled[:num_train_samples],
    train_inputs_add_prog_shuffled[num_train_samples:],
)
train_outputs_prog_train, train_outputs_prog_val = (
    train_outputs_prog_shuffled[:num_train_samples],
    train_outputs_prog_shuffled[num_train_samples:],
)

#%%  train health prognosis model

set_random_seed(123)

patience = 5000  # epochs without validation improvement before stopping
no_improvement = 0  # epochs since the last validation improvement

num_points = 25  # points per predicted curve
prog_model = Decoder_Prog(num_points=num_points).to(device)
print(prog_model)
total_params = sum(p.numel() for p in prog_model.parameters())
print(f"Total parameters: {total_params}")

optimizer = torch.optim.AdamW(prog_model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.98, patience=100)

best_loss = float('inf')  # lowest validation loss seen so far
best_model_path_prog = 'best_model_prog.pth'
max_epoch = 50000
total_start_time = time.time()
for epoch in range(max_epoch):

    # ---- full-batch training step ----
    prog_model.train()
    prognostic_curve, predict_life = prog_model(train_inputs_add_prog_train)
    reg_loss_prog = regression_loss_prog(prognostic_curve, train_outputs_prog_train, predict_life)

    optimizer.zero_grad()
    reg_loss_prog.backward()
    torch.nn.utils.clip_grad_norm_(prog_model.parameters(), max_norm=1.0)
    # Step BEFORE validating: the validation loss, the scheduler decision and
    # the saved checkpoint must all refer to the same set of weights.
    optimizer.step()

    # ---- validation on the updated weights ----
    prog_model.eval()
    with torch.no_grad():
        prognostic_curve_val, predict_life_val = prog_model(train_inputs_add_prog_val)
        val_reg_loss_prog = regression_loss_prog(prognostic_curve_val, train_outputs_prog_val, predict_life_val)

    scheduler.step(val_reg_loss_prog)

    # Checkpoint on every validation improvement
    if val_reg_loss_prog < best_loss:
        best_loss = val_reg_loss_prog
        no_improvement = 0
        torch.save(prog_model.state_dict(), best_model_path_prog)
        print(f'Epoch {epoch}, '
              f'Training Loss: {reg_loss_prog.item()},'
              f'Validation Loss: {val_reg_loss_prog.item()}, LR: {scheduler.optimizer.param_groups[0]["lr"]:.6f}')
    else:
        no_improvement += 1

    # Stop when validation stalls or the learning rate has decayed to a negligible value
    if no_improvement >= patience:
        print(f"Early stopping at epoch {epoch}, Validation loss has not improved in {patience} epochs.")
        break
    if scheduler.optimizer.param_groups[0]["lr"] < 5e-7:
        print(f"Early stopping at epoch {epoch}, learning rate fell below 5e-7.")
        break

total_end_time = time.time()
print('Time for training:', total_end_time-total_start_time)

#%% test results for health prognostics

# Restore the best checkpoint and run it on the test, training and validation inputs.
_best_state = torch.load(best_model_path_prog, weights_only=True)
prog_model.load_state_dict(_best_state)
prog_model.eval()
with torch.no_grad():
    predicted_curve, predict_life = prog_model(test_inputs_add_prog_tensor)
    predicted_curve_train, predict_life_train = prog_model(train_inputs_add_prog_train)
    predicted_curve_val, predict_life_val = prog_model(train_inputs_add_prog_val)


def _host_array(tensor):
    """Move a tensor to the CPU and expose it as a NumPy array."""
    return tensor.cpu().numpy()


# Test split
predicted_curve, predict_life, true_curve_test = map(
    _host_array, (predicted_curve, predict_life, test_outputs_prog_tensor))

# Training split
predicted_curve_train, true_curve_train, predict_life_train = map(
    _host_array, (predicted_curve_train, train_outputs_prog_train, predict_life_train))

# Validation split
predicted_curve_val, true_curve_val, predict_life_val = map(
    _host_array, (predicted_curve_val, train_outputs_prog_val, predict_life_val))

# Parity plot: last EFC point of the predicted curve vs. the true one (rescaled by 500),
# with an inset histogram of the test errors.
fig, ax = plt.subplots(figsize=(6.5/2.54, 5.5/2.54), dpi=600)
plt.ion()
plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)

_parity_series = (
    (true_curve_train, predicted_curve_train, '#8da0cd', 's', 'Train'),
    (true_curve_val, predicted_curve_val, '#fb8d62', '^', 'Validation'),
    (true_curve_test, predicted_curve, '#87d5c6', 'o', 'Test'),
)
for _truth, _estimate, _colour, _marker, _label in _parity_series:
    scatter = plt.scatter(_truth[:, -1, 1]*500, _estimate[:, -1, 1]*500,
                          alpha=1, color=_colour, marker=_marker,
                          linewidth=0.0, s=10, edgecolors=None, label=_label)

# Identity line
_diagonal = np.linspace(0, 1200)
plt.plot(_diagonal, _diagonal, '-', color='dimgray')
plt.xlabel('Real', labelpad=1)
plt.ylabel('Predictions', labelpad=1)
plt.ylim([0, 1200])
plt.xlim([0, 1200])
ax.tick_params(axis='both', which='both', length=0)
plt.legend(frameon=False, labelspacing=0.01, fontsize=10,
           handletextpad=-0.5, loc='upper left', bbox_to_anchor=(-0.05, 1.0))

# Inset: distribution of the test-set errors
ax_inset = inset_axes(ax, width="30%", height="25%",
                      bbox_to_anchor=(-0.06, -0.6, 1.05, 1), bbox_transform=ax.transAxes)
errors = predicted_curve[:, -1, 1]*500 - true_curve_test[:, -1, 1]*500
ax_inset.hist(errors, bins=25, color='dimgray', alpha=0.7)
ax_inset.tick_params(axis='both', which='major', bottom=False, top=False,
                     left=False, right=False, labelsize=10)
plt.show()

# Parity plot: output of the dedicated life head vs. the last true EFC point
# (both rescaled by 500), with an inset histogram of the test errors.
fig, ax = plt.subplots(figsize=(6.5/2.54, 5.5/2.54), dpi=600)
plt.ion()
plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)

_life_series = (
    (true_curve_train, predict_life_train, '#8da0cd', 's', 'Train'),
    (true_curve_val, predict_life_val, '#fb8d62', '^', 'Validation'),
    (true_curve_test, predict_life, '#87d5c6', 'o', 'Test'),
)
for _truth, _life, _colour, _marker, _label in _life_series:
    scatter = plt.scatter(_truth[:, -1, 1]*500, _life*500,
                          alpha=1, color=_colour, marker=_marker,
                          linewidth=0.0, s=10, edgecolors=None, label=_label)

# Identity line
_diagonal = np.linspace(0, 1200)
plt.plot(_diagonal, _diagonal, '-', color='dimgray')
plt.xlabel('Real', labelpad=1)
plt.ylabel('Predictions', labelpad=1)
plt.ylim([0, 1200])
plt.xlim([0, 1200])
ax.tick_params(axis='both', which='both', length=0)
plt.legend(frameon=False, labelspacing=0.01, fontsize=10,
           handletextpad=-0.5, loc='upper left', bbox_to_anchor=(-0.05, 1.0))

# Inset: distribution of the test-set errors
ax_inset = inset_axes(ax, width="30%", height="25%",
                      bbox_to_anchor=(-0.06, -0.6, 1.05, 1), bbox_transform=ax.transAxes)
errors = predict_life.reshape(-1)*500 - true_curve_test[:, -1, 1]*500
ax_inset.hist(errors, bins=25, color='dimgray', alpha=0.7)
ax_inset.tick_params(axis='both', which='major', bottom=False, top=False,
                     left=False, right=False, labelsize=10)
plt.show()


# Curves of all three splits stacked (test, train, validation) for the pooled report
all_curve_predict = np.vstack((predicted_curve, predicted_curve_train, predicted_curve_val))
all_curve_true = np.vstack((true_curve_test, true_curve_train, true_curve_val))

# (ground truth, prediction) pairs in physical units, reported in this order:
#   1. test: last EFC point of the predicted curve
#   2. test: life head output vs. last true EFC point
#   3. all splits: last EFC point of the predicted curve
#   4. test: whole EFC curve
#   5. test: whole capacity curve
_metric_pairs = (
    (500*true_curve_test[:, -1, 1].reshape(-1), 500*predicted_curve[:, -1, 1].reshape(-1)),
    (500*true_curve_test[:, -1, 1].reshape(-1), 500*predict_life.reshape(-1)),
    (500*all_curve_true[:, -1, 1].reshape(-1), 500*all_curve_predict[:, -1, 1].reshape(-1)),
    (500*true_curve_test[:, :, 1].reshape(-1), 500*predicted_curve[:, :, 1].reshape(-1)),
    (nominal_capacity*true_curve_test[:, :, 0].reshape(-1),
     nominal_capacity*predicted_curve[:, :, 0].reshape(-1)),
)

for _truth, _estimate in _metric_pairs:
    RMSE_predicted, MAE_predicted = error_evaluation(_truth, _estimate)
    # Coefficient of determination from residual and total sums of squares
    residuals = _estimate - _truth
    ss_res = np.sum(residuals**2)
    ss_tot = np.sum((_truth - np.mean(_truth))**2)
    r_squared = 1 - (ss_res / ss_tot)
    print("R^2:", r_squared)


# Capacity-vs-EFC trajectories on the test set: ground truth (solid) drawn
# first, predictions (dashed) drawn on top.
fig, ax = plt.subplots(figsize=(6.5/2.54, 5.5/2.54), dpi=600)
plt.ion()
plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)

for i, _true_traj in enumerate(true_curve_test):
    plt.plot(_true_traj[:, 1]*500, _true_traj[:, 0]*nominal_capacity,
             color='#8da0cd', linestyle='-')

for i, _ in enumerate(true_curve_test):
    _pred_traj = predicted_curve[i]
    plt.plot(_pred_traj[:, 1]*500, _pred_traj[:, 0]*nominal_capacity,
             color='#87d5c6', alpha=0.9, linestyle='--')

plt.xlabel('EFCs')
plt.ylabel('Capacity [Ah]')
plt.show()

#%% shap analysis for prognosis decoder

class DecoderModel(torch.nn.Module):
    """Wrapper that reduces the prognosis decoder to three scalars per sample.

    The output columns are, in order: the mean of curve channel 0, the mean
    of curve channel 1, and the last point of curve channel 1.  The decoder's
    separate life output is ignored.
    """

    def __init__(self, decoder_prog):
        super().__init__()
        self.decoder_prog = decoder_prog

    def forward(self, inputs):
        """Return a ``(batch, 3)`` tensor of curve summaries."""
        curve, _ = self.decoder_prog(inputs)
        summaries = (
            torch.mean(curve[:, :, 0], dim=1),  # mean of channel 0
            torch.mean(curve[:, :, 1], dim=1),  # mean of channel 1
            curve[:, -1, 1],                    # final point of channel 1
        )
        return torch.stack(summaries, dim=-1)


decoder_model = DecoderModel(prog_model).to(device)

# The test inputs serve both as the SHAP background set and as the explained samples.
decoder_explainer = shap.DeepExplainer(decoder_model, test_inputs_add_prog_tensor)
decoder_shap_values = decoder_explainer.shap_values(test_inputs_add_prog_tensor, check_additivity=False)

# One entry per wrapper output.
# NOTE(review): indexing the first axis assumes shap returns one array per
# model output here -- confirm against the installed shap version.
Q_curve_shap_values, C_curve_shap_values, L_curve_shap_values = (
    decoder_shap_values[_k] for _k in range(3)
)

feature_names = ["Cp", "Cn", "x0", "y0", "Q", "EFC"]

cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap', ['#8da0cd', '#87d5c6'])


def _draw_shap_summary(shap_values):
    """Style the current figure and draw a SHAP summary plot of ``shap_values`` on it."""
    plt.ion()
    plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
    plt.tick_params(top='on', right='on', which='both')
    plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    shap.summary_plot(shap_values,
                      features=test_inputs_add_prog_tensor.cpu().detach().numpy(),
                      feature_names=feature_names,
                      max_display=6, plot_size=(9/2.54, 5.5/2.54),
                      cmap=cmap)
    plt.show()


# First wrapper output
plt.figure(num=None, figsize=(10/2.54, 5/2.54), dpi=600)
_draw_shap_summary(Q_curve_shap_values)

# Second wrapper output
fig, ax = plt.subplots(figsize=(10/2.54, 5/2.54), dpi=600)
_draw_shap_summary(C_curve_shap_values)

# Third wrapper output
plt.figure(num=None, figsize=(10/2.54, 5/2.54), dpi=600)
_draw_shap_summary(L_curve_shap_values)


class DecoderModel2(torch.nn.Module):
    """Wrapper exposing only the life output of the prognosis decoder."""

    def __init__(self, decoder_prog):
        super().__init__()
        self.decoder_prog = decoder_prog

    def forward(self, inputs):
        """Return the second element of the decoder output; the curve is discarded."""
        decoder_outputs = self.decoder_prog(inputs)
        _, life = decoder_outputs
        return life


decoder_model2 = DecoderModel2(prog_model).to(device)

# SHAP values of the life head; the test inputs are both background and explained samples.
decoder_explainer2 = shap.DeepExplainer(decoder_model2, test_inputs_add_prog_tensor)
decoder_shap_values2 = decoder_explainer2.shap_values(test_inputs_add_prog_tensor, check_additivity=False)

plt.figure(num=None, figsize=(10/2.54, 5/2.54), dpi=600)
plt.ion()
plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
_summary_options = dict(
    features=test_inputs_add_prog_tensor.cpu().detach().numpy(),
    feature_names=feature_names,
    max_display=6,
    plot_size=(9/2.54, 5.5/2.54),
    cmap=cmap,
)
shap.summary_plot(decoder_shap_values2, **_summary_options)
plt.show()


#%
class DecoderModelPerPoint(torch.nn.Module):
    """Wrapper exposing a single point of the predicted prognosis curve."""

    def __init__(self, decoder_prog, time_step):
        """Store the wrapped decoder and the curve index to expose.

        Args:
            decoder_prog: Decoder returning ``(prognostic_curve, life)``.
            time_step: Index along the curve's second axis.
        """
        super().__init__()
        self.decoder_prog = decoder_prog
        self.time_step = time_step

    def forward(self, inputs):
        """Return a ``(batch, 2)`` tensor: channels 0 and 1 of the curve at ``time_step``."""
        curve = self.decoder_prog(inputs)[0]
        channel_0 = curve[:, self.time_step, 0]
        channel_1 = curve[:, self.time_step, 1]
        return torch.stack((channel_0, channel_1), dim=-1)

def calculate_shap_for_each_point(prog_model, test_inputs_add_prog_tensor, num_points=None):
    """Compute SHAP values separately for every point of the predicted curve.

    For each curve index a ``DecoderModelPerPoint`` wrapper is explained with
    ``shap.DeepExplainer``, using ``test_inputs_add_prog_tensor`` both as the
    background set and as the samples to explain.

    Args:
        prog_model: Prognosis decoder returning ``(prognostic_curve, life)``.
        test_inputs_add_prog_tensor: Decoder inputs to explain.
        num_points: Number of curve points to explain.  Defaults to
            ``prog_model.num_points`` when the model has that attribute,
            otherwise 25.

    Returns:
        The per-point SHAP results stacked along a new axis 1.
        NOTE(review): when shap returns a list with one ``(samples, features)``
        array per output, the result is
        ``(outputs, time_points, samples, features)`` -- confirm against the
        installed shap version.
    """
    if num_points is None:
        # Follow the decoder's own curve length instead of a hard-coded 25.
        num_points = getattr(prog_model, 'num_points', 25)

    shap_values_per_point = []
    for time_step in range(num_points):
        print(time_step)  # progress indicator
        point_model = DecoderModelPerPoint(prog_model, time_step).to(device)
        decoder_explainer = shap.DeepExplainer(point_model, test_inputs_add_prog_tensor)
        shap_values_t = decoder_explainer.shap_values(test_inputs_add_prog_tensor, check_additivity=False)
        shap_values_per_point.append(shap_values_t)

    return np.stack(shap_values_per_point, axis=1)

shap_values_per_point = calculate_shap_for_each_point(prog_model, test_inputs_add_prog_tensor)

# Mean absolute SHAP value, taken over axis 1 of the slice belonging to each
# output of the per-point wrapper (index 0: curve channel 0, index 1: curve channel 1).
shap_values_mean_abs_0 = np.abs(shap_values_per_point[0, :, :, :]).mean(axis=1)
shap_values_mean_abs_1 = np.abs(shap_values_per_point[1, :, :, :]).mean(axis=1)

feature_names = ["Cp", "Cn", "x0", "y0", "Q", "EFC"]

# One heat map per output: features on the y axis, x ticks hidden.
_heatmap_specs = (
    (10.5, shap_values_mean_abs_0, 'Mean Absolute SHAP Values (Capacity)'),
    (10, shap_values_mean_abs_1, 'Mean Absolute SHAP Values (EFC)'),
)
for _width_cm, _mean_abs, _title in _heatmap_specs:
    plt.figure(num=None, figsize=(_width_cm/2.54, 4/2.54), dpi=600)
    plt.ion()
    plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
    plt.tick_params(top='on', right='on', which='both')
    plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    plt.imshow(_mean_abs.T, aspect='auto', cmap=cmap, interpolation='nearest')
    plt.yticks(ticks=np.arange(len(feature_names)), labels=feature_names)
    plt.colorbar()
    plt.title(_title)
    plt.xticks([])
    plt.tight_layout()
    plt.show()
