# -*- coding: utf-8 -*-
"""
"""

from beep import structure
import os
import numpy as np
from numpy import gradient
import torch
import torch.nn as nn
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
import random
import seaborn as sns
from matplotlib import rcParams
import pandas as pd
from sklearn import metrics
import matplotlib.colors as mcolors
import time
import shap
from pyswarm import pso

from sklearn.model_selection import KFold

def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and current CUDA device) RNGs.

    Also switches cuDNN to deterministic mode and turns off its autotuner.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Global matplotlib font used by every figure in this script.
font_properties = {'family': 'Times New Roman', 'size': 12}
rcParams.update({'font.family': font_properties['family'], 'font.size': font_properties['size']})

#%% Load half cell data and fitting function construction
# The half-cell CSV files are read from the current working directory;
# `folder_loc` is only kept for the alternative (commented) path layout.
folder_loc = '..\\006_half_cell'
file_name = 'anode_SiO_Gr_discharge_Cover5_smoothed_dvdq_JS.csv'
# file_path = f'{folder_loc}\\{file_name}'
file_path = f'{file_name}'
OCPn_data = pd.read_csv(file_path)

file_name = 'cathode_NCA_discharge_Cover5_smoothed_dvdq_JS.csv'
# file_path = f'{folder_loc}\\{file_name}'
file_path = f'{file_name}'
OCPp_data = pd.read_csv(file_path)

# Keep the lookup tables on the GPU when one is available, but fall back to
# the CPU instead of crashing on CUDA-less machines.
_ocp_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Lookup tables (SOC grid -> electrode potential) used for interpolation.
# NOTE(review): the interpolation relies on torch.searchsorted, which
# presumably needs 'SOC_linspace' to be sorted ascending - confirm for both CSVs.
OCPn_SOC = torch.tensor(OCPn_data['SOC_linspace'].values, dtype=torch.float32).to(_ocp_device)
OCPn_V = torch.tensor(OCPn_data['Voltage'].values, dtype=torch.float32).to(_ocp_device)
OCPp_SOC = torch.tensor(OCPp_data['SOC_linspace'].values, dtype=torch.float32).to(_ocp_device)
OCPp_V = torch.tensor(OCPp_data['Voltage'].values, dtype=torch.float32).to(_ocp_device)


def torch_interpolate(x, xp, fp):
    """Piecewise-linear interpolation of the samples ``(xp, fp)`` at ``x``.

    The bracketing interval is found with ``torch.searchsorted``; queries
    outside ``[xp[0], xp[-1]]`` are extrapolated along the first / last
    segment because the upper index is clamped to ``[1, len(xp) - 1]``.
    """
    upper = torch.searchsorted(xp, x).clamp(1, len(xp) - 1)
    lower = upper - 1

    x_lo, x_hi = xp[lower], xp[upper]
    f_lo, f_hi = fp[lower], fp[upper]

    rise = f_hi - f_lo
    return f_lo + rise * (x - x_lo) / (x_hi - x_lo)

def OCP_p(SOC_p):
    """Interpolate the cathode half-cell table (OCPp_SOC -> OCPp_V) at ``SOC_p``."""
    # The query must live on the same device as the lookup table.
    query = SOC_p.to(OCPp_SOC.device)
    return torch_interpolate(query, OCPp_SOC, OCPp_V)

def OCP_n(SOC_n):
    """Interpolate the anode half-cell table (OCPn_SOC -> OCPn_V) at ``SOC_n``."""
    # The query must live on the same device as the lookup table.
    query = SOC_n.to(OCPn_SOC.device)
    return torch_interpolate(query, OCPn_SOC, OCPn_V)


#%%
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
_backend = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_backend)
print(f'Using device: {device}')

class EncoderMLP(nn.Module):
    """MLP encoder mapping a flattened input window to four scalar parameters.

    Args:
        num_feature: Number of channels per point of the input window.
        hidden_size: Width of the two hidden layers.
        output_size: Width of the latent layer feeding the four output heads.
        num_points: Number of points per input window (previously fixed at
            200); the first layer expects ``num_points * num_feature`` inputs.

    The four heads (``Cp``, ``Cn``, ``x0``, ``y0``) start with zero weights and
    a bias of 0.9, so a freshly constructed encoder outputs 0.9 for every
    parameter regardless of its input.
    """

    _HEAD_NAMES = ('cp_out', 'cn_out', 'x0_out', 'y0_out')
    _HEAD_INIT_BIAS = 0.9

    def __init__(self, num_feature=3, hidden_size=128, output_size=64, num_points=200):
        super().__init__()
        hidden_size = int(hidden_size)
        input_size = num_points * num_feature

        # Layer creation order is kept as-is so that seeded initialisation and
        # state-dict keys stay compatible with existing checkpoints.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

        self.cp_out = nn.Linear(output_size, 1)
        self.cn_out = nn.Linear(output_size, 1)
        self.x0_out = nn.Linear(output_size, 1)
        self.y0_out = nn.Linear(output_size, 1)

        with torch.no_grad():
            for head_name in self._HEAD_NAMES:
                head = getattr(self, head_name)
                head.weight.fill_(0.0)
                head.bias.fill_(self._HEAD_INIT_BIAS)

    def forward(self, x):
        """Return ``(Cp, Cn, x0, y0)``, each of shape ``(batch, 1)`` and non-negative."""
        # Flatten to (batch, num_points * num_feature); reshape (unlike view)
        # also accepts non-contiguous input.
        x = x.reshape(x.size(0), -1)

        for layer in (self.fc1, self.fc2, self.fc3):
            x = torch.relu(layer(x))

        # ReLU keeps every parameter >= 0 (it can be exactly 0).
        Cp = torch.relu(self.cp_out(x))
        Cn = torch.relu(self.cn_out(x))
        x0 = torch.relu(self.x0_out(x))
        y0 = torch.relu(self.y0_out(x))
        return Cp, Cn, x0, y0


class Decoder(nn.Module):
    """MLP decoder turning ``(Cp, Cn, x0, y0)`` into a voltage and a capacity curve.

    Two shared hidden layers (4 -> 128 -> 256) feed two linear heads, each
    producing ``num_points`` values.
    """

    def __init__(self, num_points=200):
        super().__init__()
        self.num_points = num_points
        self.fc1 = nn.Linear(4, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fcv = nn.Linear(256, num_points)  # voltage head
        self.fcq = nn.Linear(256, num_points)  # capacity head

    def forward(self, Cp, Cn, x0, y0):
        """Return a ``(batch, num_points, 2)`` tensor: channel 0 from ``fcv``, channel 1 from ``fcq``."""
        hidden = torch.cat([Cp, Cn, x0, y0], dim=-1)
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))

        voltage_curve = self.fcv(hidden)
        capacity_curve = self.fcq(hidden)
        return torch.stack([voltage_curve, capacity_curve], dim=-1)


class BatteryModel(nn.Module):
    """Encoder/decoder model that also derives an OCV curve from half-cell data.

    Args:
        encoder_name: Encoder architecture; only ``'MLP'`` is supported.
        num_points: Number of points per decoded curve.
        eps: Lower bound applied to ``Cp`` / ``Cn`` before they are used as
            divisors. The encoder heads end in a ReLU and can output exactly
            0, which would otherwise give inf/NaN electrode SOCs.

    Raises:
        ValueError: If ``encoder_name`` is not recognised.
    """

    def __init__(self, encoder_name, num_points=200, eps=1e-3):
        super().__init__()
        if encoder_name == 'MLP':
            self.encoder = EncoderMLP()
        else:
            raise ValueError(f"Unknown encoder name: {encoder_name}")

        self.decoder = Decoder(num_points=num_points)
        self.eps = eps

    def forward(self, x):
        """Run encoder and decoder, then rebuild the OCV from the electrode tables.

        Returns:
            ``(OCV_Q_curve, Cp, Cn, x0, y0, Calculated_OCV, SOC_p, SOC_n)``.
            ``Cp`` and ``Cn`` are returned exactly as the encoder produced
            them; the ``eps`` floor is applied only inside the divisions.
        """
        Cp, Cn, x0, y0 = self.encoder(x)
        OCV_Q_curve = self.decoder(Cp, Cn, x0, y0)

        # Channel 1 of the decoded curve is the capacity; dividing it by the
        # electrode capacity and subtracting from the initial value gives the
        # electrode SOC at every point.
        Q_curve = OCV_Q_curve[:, :, 1]
        SOC_p = y0 - Q_curve / Cp.clamp(min=self.eps)
        SOC_n = x0 - Q_curve / Cn.clamp(min=self.eps)

        # Full-cell voltage = cathode potential - anode potential, scaled by
        # 4.2 like the voltage channel of the training targets.
        Calculated_OCV = (OCP_p(SOC_p) - OCP_n(SOC_n)) / 4.2

        return OCV_Q_curve, Cp, Cn, x0, y0, Calculated_OCV, SOC_p, SOC_n


def regression_loss(predicted_ocv_Q_curve, true_ocv_Q_curve):
    """Sum of the per-channel MSEs (channel 0, then channel 1) between two curves."""
    mse = nn.MSELoss()
    channel0_error = mse(predicted_ocv_Q_curve[:, :, 0], true_ocv_Q_curve[:, :, 0])
    channel1_error = mse(predicted_ocv_Q_curve[:, :, 1], true_ocv_Q_curve[:, :, 1])
    return channel0_error + channel1_error

def physics_loss(Cp, Cn, x0, y0, predicted_ocv_Q_curve, true_ocv_Q_curve, eps=1e-3):
    """Physics-consistency and range-constraint penalties for a predicted curve.

    Args:
        Cp, Cn, x0, y0: Encoder outputs, broadcast against the curve.
        predicted_ocv_Q_curve: Tensor whose last axis holds (voltage, capacity).
        true_ocv_Q_curve: Unused; kept so existing callers keep working.
        eps: Lower bound applied to ``Cp`` / ``Cn`` before dividing, so a
            zero capacity cannot produce inf/NaN losses.

    Returns:
        ``(total_physics_loss, total_constrain_loss)``: the MSE between the
        OCV rebuilt from the half-cell tables and the predicted voltage, and
        the sum of all range penalties.
    """
    V_pred = predicted_ocv_Q_curve[:, :, 0]
    Q_pred = predicted_ocv_Q_curve[:, :, 1]

    SOC_p = y0 - Q_pred / Cp.clamp(min=eps)
    SOC_n = x0 - Q_pred / Cn.clamp(min=eps)

    # Voltages are compared in the same 4.2-scaled units as the prediction.
    Calculated_OCV = (OCP_p(SOC_p) - OCP_n(SOC_n)) / 4.2

    v_floor = 2.7 / 4.2
    predict_range_loss = (
        # Rebuilt and predicted voltage must stay inside [2.7/4.2, 1].
        torch.mean(torch.relu(Calculated_OCV - 1))
        + torch.mean(torch.relu(v_floor - Calculated_OCV))
        + torch.mean(torch.relu(V_pred - 1))
        + torch.mean(torch.relu(v_floor - V_pred))
        # Capacity must be non-negative everywhere ...
        + torch.mean(torch.relu(-Q_pred))
        # ... and its first point is additionally penalised when positive.
        + torch.mean(torch.relu(predicted_ocv_Q_curve[:, 0, 1]))
    )

    # Electrode SOCs are penalised outside [0, 1].
    soc_p_constraint = torch.mean(torch.relu(SOC_p - 1.0) + torch.relu(0 - SOC_p))
    soc_n_constraint = torch.mean(torch.relu(SOC_n - 1.0) + torch.relu(0 - SOC_n))
    # Electrode capacities are penalised outside [0.1, 1.1].
    cp_constraint = torch.mean(torch.relu(Cp - 1.1) + torch.relu(0.1 - Cp))
    cn_constraint = torch.mean(torch.relu(Cn - 1.1) + torch.relu(0.1 - Cn))

    total_physics_loss = nn.MSELoss()(Calculated_OCV, V_pred)

    total_constrain_loss = (
        cp_constraint
        + cn_constraint
        + predict_range_loss
        + soc_p_constraint
        + soc_n_constraint
    )

    return total_physics_loss, total_constrain_loss


def total_loss(predicted_ocv_Q_curve, true_ocv_Q_curve, Cp, Cn, x0, y0, phys_weight, reg_weight, constraint_weight):
    """Weighted sum of regression, physics and constraint losses.

    Returns ``(total, reg_loss, phys_loss, constraint_loss)`` where the last
    three are the unweighted components.
    """
    reg_loss = regression_loss(predicted_ocv_Q_curve, true_ocv_Q_curve)
    phys_loss, constraint_loss = physics_loss(
        Cp, Cn, x0, y0, predicted_ocv_Q_curve, true_ocv_Q_curve
    )

    weighted_reg = reg_weight * reg_loss
    weighted_phys = phys_weight * phys_loss
    weighted_constraint = constraint_weight * constraint_loss

    combined = weighted_reg + weighted_phys + weighted_constraint
    return combined, reg_loss, phys_loss, constraint_loss

def error_evaluation(prediction, real_value):
    """Print and return ``(RMSE, MAE)`` between ``prediction`` and ``real_value``."""
    squared_error = metrics.mean_squared_error(prediction, real_value)
    rmse = squared_error ** 0.5
    mae = metrics.mean_absolute_error(prediction, real_value)
    print("RMSE, MAE", rmse, mae)
    return rmse, mae

#%%

def _first_hit(condition, offset, fallback):
    """Return ``offset`` plus the index of the first True in ``condition``, else ``fallback``."""
    hits = np.flatnonzero(condition)
    return (offset + int(hits[0])) if hits.size else fallback


def _discharge_bounds(V, I, Q):
    """Locate the start of discharge and the 2.7 V cut-off within one cycle.

    Both searches start at index 2 and fall back to the last sample when no
    point qualifies. Requires ``len(Q) >= 3``.
    """
    last = len(Q) - 1
    # Discharge starts where capacity begins to grow under negative current.
    start = _first_hit((Q[2:] > Q[1:-1]) & (I[2:] < 0), 2, last)
    stop = _first_hit(V[2:] <= 2.7, 2, last)
    return start, stop


def _cycle_arrays(data, cycle_type):
    """Return ``(voltage, current, discharge_capacity, step_time)`` arrays of one cycle type."""
    rows = data[data["cycle_type"] == cycle_type]
    return tuple(
        np.array(rows[column])
        for column in ("voltage", "current", "discharge_capacity", "step_time")
    )


def _resample(values, num_points):
    """Linearly resample ``values`` onto ``num_points`` equally spaced positions."""
    source_axis = np.linspace(0, 1, len(values))
    return interp1d(source_axis, values, kind='linear')(np.linspace(0, 1, num_points))


def _build_sample(data, predic_c_rate, num_points, nominal_capacity):
    """Build one ``(input, output)`` pair, or return None when the cell must be skipped."""
    # --- Input: partial 1C discharge segment ---
    V, I, Q, t = _cycle_arrays(data, "1C_Cycle")
    if len(Q) < 3:
        return None
    start, stop = _discharge_bounds(V, I, Q)

    window = Q[start:stop]
    # Skip cells without a usable 1C discharge (max capacity below 2 in the
    # units of 'discharge_capacity').
    if window.size == 0 or np.max(window) < 2:
        return None

    # Keep the segment up to the first sample more than 1200 (step_time
    # units) after the start of discharge, or up to just before the cut-off.
    end = _first_hit((t[start:stop] - t[start]) > 1200, start, stop - 1)
    if end - start < 2:
        return None  # interpolation needs at least two samples

    segment_voltage = _resample(V[start:end] / 4.2, num_points)
    segment_capacity = _resample(
        Q[start:end] / nominal_capacity - Q[start] / nominal_capacity, num_points
    )
    # Constant third feature channel; sized by num_points (was fixed at 200).
    segment_ones = np.ones(num_points)

    # --- Output: full discharge curve at the requested rate ---
    V, I, Q, _ = _cycle_arrays(data, predic_c_rate)
    if len(Q) < 3:
        return None
    start, stop = _discharge_bounds(V, I, Q)
    target_V = V[start:stop + 1] / 4.2
    target_Q = Q[start:stop + 1] / nominal_capacity - Q[start] / nominal_capacity
    if len(target_V) < 2:
        return None

    input_data = np.stack([segment_voltage, segment_capacity, segment_ones], axis=-1)
    output_data = np.stack(
        [_resample(target_V, num_points), _resample(target_Q, num_points)], axis=-1
    )
    return input_data, output_data


def read_resval_data(file_name, idx_seed, predic_c_rate, num_points=200):
    """Load one cell file and build a (1C segment -> target-rate curve) sample.

    Args:
        file_name: Path of the Maccor file, read with
            ``structure.MaccorDatapath.from_file``.
        idx_seed: Seed passed to ``set_random_seed`` before the sample is built.
        predic_c_rate: Cycle-type label of the target curve (e.g. ``"C/40_Cycle"``).
        num_points: Number of points every curve is resampled to.

    Returns:
        ``(inputs, outputs)`` NumPy arrays of shape ``(1, num_points, 3)`` and
        ``(1, num_points, 2)``. Voltages are divided by 4.2 and capacities by
        4.84, with capacity offset to start at 0. Both arrays are empty when
        the cell lacks a usable 1C or target-rate cycle.
    """
    nominal_capacity = 4.84
    datapath = structure.MaccorDatapath.from_file(file_name)
    # Cycle-type label for each 'cycle_index' value (index into this list).
    all_cycle_types = [
            "start_discharge",
            "C/80_Cycle",
            "GITT",
            "C/40_Cycle",
            "0.05A_Cycle_mistake",
            "C/10_Cycle",
            "C/7_cycle",
            "C/5_Cycle",
            "1C_Cycle",
            "2C_Cycle",
            "charge_for_storage",
    ]
    datapath.raw_data["cycle_type"] = datapath.raw_data["cycle_index"].apply(lambda x: all_cycle_types[x])
    data = datapath.raw_data

    set_random_seed(idx_seed)

    inputs = []
    outputs = []
    sample = _build_sample(data, predic_c_rate, num_points, nominal_capacity)
    if sample is not None:
        inputs.append(sample[0])
        outputs.append(sample[1])

    # Shapes: (num_samples, num_points, 3) and (num_samples, num_points, 2).
    return np.array(inputs), np.array(outputs)


# Cycle types for which a separate model is fine-tuned and evaluated.
predict_rate_list = [
    "C/80_Cycle",
    "C/40_Cycle",
    "0.05A_Cycle_mistake",
    "C/10_Cycle",
    "C/7_cycle",
    "C/5_Cycle",
]

# Accumulators, one entry appended per cycle type.
predict_ocv_all_diff_rate, calculated_ocv_all_diff_rate, true_ocv_all_diff_rate = [], [], []
Cp_all_diff_rate, Cn_all_diff_rate, x0_all_diff_rate, y0_all_diff_rate = [], [], [], []
shap_all_Q_diff_rate, shap_all_V_diff_rate = [], []
C_rates_all_diff_rate, measure_all_cap_rate = [], []
for predic_c_rate in predict_rate_list:
    print(predic_c_rate)
    # Cell files live in '../ResValData'; os.path.join keeps the path portable.
    folder_loc = os.path.abspath(os.path.join(os.pardir, 'ResValData'))
    # os.listdir returns entries in arbitrary order; sorting makes the sample
    # order, and therefore the K-fold membership below, reproducible.
    file_list = sorted(f for f in os.listdir(folder_loc) if f.startswith('ResVal'))
    nominal_capacity = 4.84

    all_inputs, all_outputs = [], []
    for i, battery in enumerate(file_list):
        # This cell file is deliberately excluded.
        if battery == 'ResVal_000084_0000ED.072':
            continue
        file_name = os.path.join(folder_loc, battery)
        # print("Loading cell", battery)
        cap = []  # not read in this part of the script; kept for later cells
        input_data, output_data = read_resval_data(file_name, i, predic_c_rate)
        if len(input_data) == 0:
            print('skip')
            continue
        all_inputs.append(input_data)
        all_outputs.append(output_data)
    if not all_inputs:
        raise RuntimeError(f"No usable samples for {predic_c_rate} in {folder_loc}")
    all_inputs = np.concatenate(all_inputs, axis=0)
    all_outputs = np.concatenate(all_outputs, axis=0)

    # Scatter of the last capacity value of each target curve against the
    # last capacity value of the corresponding input segment.
    plt.figure(num=None,figsize=(6/2.54,5/2.54),dpi=600)
    plt.ion()
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    plt.tick_params(top='on', right='on', which='both')
    plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    plt.plot(all_outputs[:,-1,1] ,all_inputs[:,-1,1], 'o')
    plt.xlabel('Real')
    plt.ylabel('Measurement')
    plt.show()


    kf = KFold(n_splits=5, shuffle=True, random_state=0)
    # Per-fold results for the current cycle type.
    predict_ocv_all = []
    Cp_all = []
    Cn_all = []
    x0_all = []
    y0_all = []
    calculated_ocv_all = []
    true_ocv_all = []
    C_rate_all = []
    prior_cap_all = []
    shap_all_Q = []
    shap_all_V = []
    for fold, (train_index, test_index) in enumerate(kf.split(all_inputs)):
        print(f"Fold {fold + 1}")
        train_inputs, test_inputs = all_inputs[train_index], all_inputs[test_index]
        train_outputs, test_outputs = all_outputs[train_index], all_outputs[test_index]

        set_random_seed(123)
        train_inputs_tensor = torch.tensor(train_inputs, dtype=torch.float32).to(device)
        train_outputs_tensor = torch.tensor(train_outputs, dtype=torch.float32).to(device)
        test_inputs_tensor = torch.tensor(test_inputs, dtype=torch.float32).to(device)
        test_outputs_tensor = torch.tensor(test_outputs, dtype=torch.float32).to(device)

        # Shuffle the training fold, then hold out the last 20 % for validation.
        num_samples = train_inputs_tensor.shape[0]
        shuffled_indices = torch.randperm(num_samples)
        train_inputs_shuffled = train_inputs_tensor[shuffled_indices]
        train_outputs_shuffled = train_outputs_tensor[shuffled_indices]

        val_split = 0.2
        num_train_samples = int((1 - val_split) * num_samples)

        train_inputs_train = train_inputs_shuffled[:num_train_samples]
        train_outputs_train = train_outputs_shuffled[:num_train_samples]
        train_inputs_val = train_inputs_shuffled[num_train_samples:]
        train_outputs_val = train_outputs_shuffled[num_train_samples:]

        #%
        set_random_seed(123)

        patience = 2000     # epochs without validation improvement before stopping
        no_improvement = 0  # consecutive epochs without improvement

        num_points = 200
        model = BatteryModel(encoder_name = 'MLP', num_points=num_points).to(device)
        # Start every fold from the pretrained weights; map_location lets a
        # checkpoint saved on a GPU load on a CPU-only machine as well.
        model.load_state_dict(torch.load('best_model.pth', weights_only=True, map_location=device))

        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.95, patience=100)

        best_loss = float('inf')  # best validation loss seen so far
        best_model_path = 'best_model_fine_tune.pth'
        max_epoch = 20000
        phys_weight = 1
        reg_weight = 1
        constraint_weight = 0.1
        total_start_time = time.time()
        
        for epoch in range(max_epoch):
            # ---- Training step on the whole training split ----
            model.train()
            predicted_Q_ocv_curve, Cp, Cn, x0, y0, Calculated_Q_ocv, SOC_p, SOC_n = model(train_inputs_train)
            total_loss_value, reg_loss, phys_loss, constraint_loss = total_loss(
                predicted_Q_ocv_curve, train_outputs_train,
                Cp, Cn, x0, y0, phys_weight, reg_weight, constraint_weight)

            optimizer.zero_grad()
            total_loss_value.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            # Update the weights BEFORE validating, so that the validation loss
            # and the checkpoint saved below refer to the same parameters.
            optimizer.step()

            # ---- Validation on the held-out split ----
            model.eval()
            with torch.no_grad():
                predicted_val_Q_ocv, Cp, Cn, x0, y0, Calculated_val_Q_ocv, SOC_p, SOC_n = model(train_inputs_val)
                val_total_loss_value, val_reg_loss, val_phys_loss, val_constraint_loss = total_loss(
                    predicted_val_Q_ocv, train_outputs_val,
                    Cp, Cn, x0, y0, phys_weight, reg_weight, constraint_weight)
            scheduler.step(val_total_loss_value)

            # Keep the checkpoint with the lowest validation loss.
            if val_total_loss_value < best_loss:
                best_loss = val_total_loss_value
                no_improvement = 0
                torch.save(model.state_dict(), best_model_path)
            else:
                no_improvement += 1

            # Stop when validation stalls or the scheduler has driven the
            # learning rate below 5e-7; report which condition fired.
            current_lr = optimizer.param_groups[0]["lr"]
            if no_improvement >= patience:
                print(f"Early stopping at epoch {epoch}, Validation loss has not improved in {patience} epochs.")
                break
            if current_lr < 5e-7:
                print(f"Early stopping at epoch {epoch}, learning rate {current_lr:.2e} dropped below 5e-7.")
                break
            
        total_end_time = time.time()
        print('Time for training:', total_end_time - total_start_time)

        #% Evaluate the best checkpoint of this fold on its test split.
        best_state = torch.load(best_model_path, weights_only=True)
        model.load_state_dict(best_state)
        model.eval()

        with torch.no_grad():
            (predicted_test_ocv, Cp, Cn, x0, y0,
             Calculated_test_ocv, SOC_p, SOC_n) = model(test_inputs_tensor)

        # Move everything to NumPy for bookkeeping and plotting.
        predicted_ocv_curve_test, Calculated_test_ocv, true_ocv_curve_test = (
            tensor.cpu().numpy()
            for tensor in (predicted_test_ocv, Calculated_test_ocv, test_outputs_tensor)
        )
        Cp, Cn, x0, y0 = (param.cpu().numpy() for param in (Cp, Cn, x0, y0))

        for bucket, fold_values in (
            (predict_ocv_all, predicted_ocv_curve_test),
            (calculated_ocv_all, Calculated_test_ocv),
            (true_ocv_all, true_ocv_curve_test),
            (Cp_all, Cp),
            (Cn_all, Cn),
            (x0_all, x0),
            (y0_all, y0),
        ):
            bucket.append(fold_values)
        # Last capacity value of each input segment.
        prior_cap_all.append(test_inputs_tensor.cpu().numpy()[:, -1, 1])

        # Re-run the encoder (as tensors again) to get the decoder inputs
        # used by the SHAP analysis that follows.
        x = test_inputs_tensor.to(device)
        Cp, Cn, x0, y0 = model.encoder(x)

        # 
        class DecoderModel(torch.nn.Module):
            """Wrap a decoder so it maps one ``(batch, 4)`` tensor to two scalars per sample.

            The four columns are passed to the decoder as ``Cp, Cn, x0, y0``;
            the output holds the mean of curve channel 0 and of channel 1.
            """

            def __init__(self, decoder):
                super().__init__()
                self.decoder = decoder

            def forward(self, inputs):
                Cp, Cn, x0, y0 = inputs.split([1, 1, 1, 1], dim=-1)
                curves = self.decoder(Cp, Cn, x0, y0)
                # Collapse each curve to its mean along the point axis.
                channel_means = [curves[:, :, channel].mean(dim=1) for channel in (0, 1)]
                return torch.stack(channel_means, dim=-1)


        # Explain the decoder alone: wrap it so SHAP sees a single (batch, 4)
        # input and two scalar outputs per sample.
        decoder_model = DecoderModel(model.decoder).to(device)

        # Encoder outputs of the test samples, one column per parameter.
        decoder_inputs = torch.cat((Cp, Cn, x0, y0), dim=-1)

        # The same tensor serves as background data and as the samples to explain.
        decoder_explainer = shap.DeepExplainer(decoder_model, decoder_inputs)
        decoder_shap_values = decoder_explainer.shap_values(decoder_inputs, check_additivity=False)

        # NOTE(review): indexing with [0] / [1] assumes shap_values returns one
        # entry per model output (mean V, mean Q) - confirm against the
        # installed shap version, whose return layout may differ.
        V_OCV_curve_shap_values = decoder_shap_values[0]
        Q_OCV_curve_shap_values = decoder_shap_values[1]

        feature_names = ["Cp", "Cn", "x0", "y0"]

        font_properties = {'family': 'Times New Roman', 'size': 12}
        rcParams.update({'font.family': font_properties['family'], 'font.size': font_properties['size']})
        cmap2 = mcolors.LinearSegmentedColormap.from_list('custom_cmap',[ '#8da0cb', '#66c2a5'])
        plt.figure(num=None,figsize=(10/2.54,6/2.54),dpi=600)
        plt.ion()
        plt.rcParams['xtick.direction'] = 'in'
        plt.rcParams['ytick.direction'] = 'in'
        plt.tick_params(top='on', right='on', which='both')
        plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
        # Summary plot for the SHAP values of the second (capacity) output only.
        shap.summary_plot(Q_OCV_curve_shap_values, 
                          features=decoder_inputs.cpu().detach().numpy(), 
                          feature_names=feature_names,
                          max_display=4,plot_size=(10/2.54, 5/2.54),
                          cmap=cmap2)
        plt.show()
        
        # Mean absolute SHAP value per feature for this fold, stored as a column vector.
        shap_all_V.append(np.mean(np.abs(V_OCV_curve_shap_values), axis=0).reshape(-1,1))
        shap_all_Q.append(np.mean(np.abs(Q_OCV_curve_shap_values), axis=0).reshape(-1,1))

        
    # Merge the per-fold results of this cycle type and file them in the
    # per-rate accumulators (folds stacked along axis 0; SHAP columns side by side).
    for rate_store, fold_chunks, join_axis in (
        (predict_ocv_all_diff_rate, predict_ocv_all, 0),
        (Cp_all_diff_rate, Cp_all, 0),
        (Cn_all_diff_rate, Cn_all, 0),
        (x0_all_diff_rate, x0_all, 0),
        (y0_all_diff_rate, y0_all, 0),
        (calculated_ocv_all_diff_rate, calculated_ocv_all, 0),
        (true_ocv_all_diff_rate, true_ocv_all, 0),
        (measure_all_cap_rate, prior_cap_all, 0),
        (shap_all_Q_diff_rate, shap_all_Q, 1),
        (shap_all_V_diff_rate, shap_all_V, 1),
    ):
        rate_store.append(np.concatenate(fold_chunks, axis=join_axis))
    # One rate label per test sample.
    C_rates_all_diff_rate.append([predic_c_rate] * np.concatenate(true_ocv_all, axis=0).shape[0])

    # Working copies covering all folds, used by the analysis below.
    predicted_ocv_curve_test = np.concatenate(predict_ocv_all, axis=0)
    Calculated_test_ocv = np.concatenate(calculated_ocv_all, axis=0)
    true_ocv_curve_test = np.concatenate(true_ocv_all, axis=0)
    Cp, Cn, x0, y0 = (
        np.concatenate(fold_chunks, axis=0)
        for fold_chunks in (Cp_all, Cn_all, x0_all, y0_all)
    )
    # Last capacity value of every true / predicted curve.
    Cap_real = true_ocv_curve_test[:, -1, 1]
    Cap_predict = predicted_ocv_curve_test[:, -1, 1]

    # Point-wise absolute errors after undoing the 4.84 / 4.2 scaling.
    error_matrix = np.abs(
        predicted_ocv_curve_test[:, :, 1] * nominal_capacity
        - true_ocv_curve_test[:, :, 1] * nominal_capacity
    )
    error_matrix2 = np.abs(
        4.2 * predicted_ocv_curve_test[:, 1:, 0] - 4.2 * true_ocv_curve_test[:, 1:, 0]
    )
    # PSO optimization function
    def pso_objective_function(params):
        """MSE between ``predict_V`` and the OCV rebuilt from ``(Cp, Cn, x0, y0)``.

        Reads the module-level ``predict_Q`` / ``predict_V`` arrays at call
        time. All points are evaluated in one batched lookup instead of one
        device round-trip per point, and the tensors are created without a
        hard-coded ``'cuda'`` (``OCP_p`` / ``OCP_n`` move them to the device
        of their lookup tables).
        """
        Cp, Cn, x0, y0 = params
        # float64 arithmetic matches the former scalar-by-scalar computation.
        charge = np.asarray(predict_Q, dtype=np.float64)
        SOC_p = torch.tensor(y0 - charge / Cp, dtype=torch.float32)
        SOC_n = torch.tensor(x0 - charge / Cn, dtype=torch.float32)

        OCPp_curve = OCP_p(SOC_p).cpu().numpy()
        OCPn_curve = OCP_n(SOC_n).cpu().numpy()
        # Subtract in float64 and round once, as the element-wise version did.
        fitted_Voc = (
            OCPp_curve.astype(np.float64) - OCPn_curve.astype(np.float64)
        ).astype(predict_Q.dtype)

        return np.mean((predict_V - fitted_Voc) ** 2)  # MSE
    
    
    # Pick the test sample with the smallest mean absolute capacity error for
    # the detailed comparison below.
    min_row_index = np.argmin(np.mean(error_matrix, axis=1))
    i = min_row_index

    # Predicted curve of the selected sample (capacity scaled, voltage in V).
    predict_Q = predicted_ocv_curve_test[i, :, 1]
    predict_V = predicted_ocv_curve_test[i, :, 0] * 4.2

    # Quantities shared by all figures, with the 4.84 / 4.2 scaling undone.
    true_Q_ah = true_ocv_curve_test[i, :, 1] * nominal_capacity
    true_V = true_ocv_curve_test[i, :, 0] * 4.2
    predict_Q_ah = predict_Q * nominal_capacity
    derived_V = Calculated_test_ocv[i, :] * 4.2

    # --- Figure: true vs predicted vs derived curve ---
    plt.figure(num=None,figsize=(7/2.54,6/2.54),dpi=600)
    plt.ion()
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    plt.tick_params(top='on', right='on', which='both')
    plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    plt.plot(true_Q_ah, true_V, label='True OCV', color='blue', linestyle='-')
    plt.plot(predict_Q_ah, predict_V, label='Predicted OCV', color='red', linestyle='--')
    plt.plot(predict_Q_ah, derived_V, label='Derived OCV',  color='orange', linestyle='-.')
    plt.xlabel('Q [Ah]')
    plt.ylabel('V [V]')
    plt.legend(frameon=False,labelspacing=0.01)
    plt.show()

    # Parameters produced by the network for this sample.
    Cp_con = Cp[i].item()
    Cn_con = Cn[i].item()
    x0_con = x0[i].item()
    y0_con = y0[i].item()

    # --- Particle-swarm fit of (Cp, Cn, x0, y0) to the predicted curve ---
    lb = [0.1, 0.1, 0.1, 0.1]  # Lower bounds
    ub = [1.1, 1.1, 1.0, 1.0]  # Upper bounds
    options = {'maxiter': 20}  # cap the number of iterations to limit run time
    optimized_params, fopt = pso(pso_objective_function, lb, ub, minstep=1e-5, minfunc=1e-5, **options)

    print(f"Constrianed parameters: Cp={Cp_con}, Cn={Cn_con}, x0={x0_con}, y0={y0_con}")
    print(f"PSO optimized parameters: Cp={optimized_params[0]}, Cn={optimized_params[1]}, x0={optimized_params[2]}, y0={optimized_params[3]}")

    Cp_opt, Cn_opt, x0_opt, y0_opt = optimized_params

    def _electrode_curves(charge, cap_p, cap_n, soc_n0, soc_p0):
        """Return (cathode potential, anode potential, their difference) along ``charge``.

        The whole curve is evaluated in one batched lookup; tensors are
        created without a hard-coded device because OCP_p / OCP_n move them
        to the device of their lookup tables.
        """
        charge64 = np.asarray(charge, dtype=np.float64)
        soc_p = torch.tensor(soc_p0 - charge64 / cap_p, dtype=torch.float32)
        soc_n = torch.tensor(soc_n0 - charge64 / cap_n, dtype=torch.float32)
        cathode = OCP_p(soc_p).cpu().numpy().astype(charge.dtype)
        anode = OCP_n(soc_n).cpu().numpy().astype(charge.dtype)
        return cathode, anode, cathode - anode

    # "fit": PSO-optimised parameters; "con": parameters from the network.
    OCPp_curve_fit, OCPn_curve_fit, fitted_Voc = _electrode_curves(
        predict_Q, Cp_opt, Cn_opt, x0_opt, y0_opt)
    OCPp_curve_con, OCPn_curve_con, Con_Voc = _electrode_curves(
        predict_Q, Cp_con, Cn_con, x0_con, y0_con)

    # Line style, legend label and colour of the eight curves, in plot order;
    # shared by the voltage figure and the dV/dQ figure.
    series_styles = [
        ('-', 'Measured OCV', 'grey'),
        ('--', 'Predicted OCV', '#84c3b7'),
        ('-.', 'Derived OCV', '#7da6c6'),
        ('-.', 'Derived OCPp', '#eaaa60'),
        ('-.', 'Derived OCPn', '#b7b2d0'),
        (':', 'Fitted OCV', '#e68b81'),
        (':', 'Fitted OCPp', '#78A040'),
        (':', 'Fitted OCPn', '#ED98CC'),
    ]
    # The measured curve uses its own capacity axis; all others share the predicted one.
    series_x = [true_Q_ah] + [predict_Q_ah] * (len(series_styles) - 1)

    # --- Figure: voltage and electrode potentials ---
    voltage_series = [true_V, predict_V, derived_V, OCPp_curve_con, OCPn_curve_con,
                      fitted_Voc, OCPp_curve_fit, OCPn_curve_fit]
    plt.figure(num=None, figsize=(8 / 2.54, 5 / 2.54), dpi=600)
    plt.ion()
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    plt.tick_params(top='on', right='on', which='both')
    plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    for (line_style, line_label, line_color), x_values, y_values in zip(series_styles, series_x, voltage_series):
        plt.plot(x_values, y_values, line_style, label=line_label, color=line_color, linewidth=2)
    plt.xlabel('Q [Ah]')
    plt.ylabel('Voltage [V]')
    plt.legend(frameon=False, labelspacing=0.1, loc='center left',  bbox_to_anchor=(1, 0.5))
    plt.show()

    # --- Differential voltage (dV/dQ) of every curve ---
    dv_dq_orig = gradient(true_V, true_Q_ah)
    dv_dq_predict = gradient(predict_V, predict_Q_ah)
    dv_dq_fit = gradient(fitted_Voc, predict_Q_ah)
    dv_dq_calculate = gradient(derived_V, predict_Q_ah)
    dv_dq_cathode_fit = gradient(OCPp_curve_fit, predict_Q_ah)
    dv_dq_anode_fit = gradient(OCPn_curve_fit, predict_Q_ah)
    dv_dq_cathode_con = gradient(OCPp_curve_con, predict_Q_ah)
    dv_dq_anode_con = gradient(OCPn_curve_con, predict_Q_ah)

    dvdq_series = [dv_dq_orig, dv_dq_predict, dv_dq_calculate, dv_dq_cathode_con,
                   dv_dq_anode_con, dv_dq_fit, dv_dq_cathode_fit, dv_dq_anode_fit]
    plt.figure(figsize=(8 / 2.54, 5 / 2.54), dpi=600)
    plt.ion()
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    plt.tick_params(top='on', right='on', which='both')
    plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    for (line_style, line_label, line_color), x_values, y_values in zip(series_styles, series_x, dvdq_series):
        plt.plot(x_values, y_values, line_style, label=line_label, color=line_color, linewidth=2)
    plt.ylabel('dV/dQ [V/Ah]')
    plt.xlabel('Q [Ah]')
    plt.ylim([-1.5,1.1])
    plt.show()
    
    

#%%
# Parity plot of predicted vs. reference capacity, one marker/colour per RPT rate.
column_labels = ["C/80", "C/40", "0.05 A", "C/10", "C/7", "C/5"]
colors = ["#F1766D", "#f3a17c", "#ecc68c", "#b3c6bb", "#78A040", "#839DD1"]
scatter_types = ["s", "o", "p", "^", ">", "h"]

fig, ax = plt.subplots(figsize=(6.5/2.54, 5/2.54), dpi=600)
ax.tick_params(bottom=False, left=False)
for i, (rate_label, rate_color, rate_marker) in enumerate(zip(column_labels, colors, scatter_types)):
    # Channel 1 at the last point of every curve, rescaled by nominal_capacity.
    Cap_real = np.array(true_ocv_all_diff_rate[i])[:, -1, 1] * nominal_capacity
    Cap_predict = np.array(predict_ocv_all_diff_rate[i])[:, -1, 1] * nominal_capacity
    ax.scatter(
        Cap_real, Cap_predict,
        marker=rate_marker, alpha=1, color=rate_color,
        linewidth=0.0, s=12, edgecolors=None, label=rate_label,
    )

# y = x reference line.
identity_axis = np.linspace(2.8, 5)
ax.plot(identity_axis, identity_axis, '-', color='dimgray')
ax.set_xlabel('Real RPT capacity [Ah]', labelpad=1)
ax.set_ylabel('Predictions [Ah]', labelpad=1)
ax.tick_params(axis='both', which='both', length=0)

# The legend is split in two: first three rates upper-left, remaining rates lower-right.
legend_style = dict(frameon=False, labelspacing=0.01, fontsize=10, handletextpad=-0.5)
legend1 = ax.legend(
    handles=ax.collections[:3], labels=column_labels[:3],
    loc='upper left', bbox_to_anchor=(-0.05, 1.05), **legend_style,
)
legend2 = ax.legend(
    handles=ax.collections[3:], labels=column_labels[3:],
    loc='lower right', **legend_style,
)
# Creating the second legend replaces the first on the axes, so re-attach it.
ax.add_artist(legend1)
plt.show()


# Scatter of the values in measure_all_cap_rate against the reference capacity, per RPT rate.
fig, ax = plt.subplots(figsize=(6.5/2.54, 5/2.54), dpi=600)
ax.tick_params(bottom=False, left=False)
for i, (rate_label, rate_color, rate_marker) in enumerate(zip(column_labels, colors, scatter_types)):
    # Reference: channel 1 at the last point of each curve, rescaled by nominal_capacity.
    Cap_real = np.array(true_ocv_all_diff_rate[i])[:, -1, 1] * nominal_capacity
    # Despite the name, this holds the entries of measure_all_cap_rate (plotted as "Measurement").
    Cap_predict = np.array(measure_all_cap_rate[i]) * nominal_capacity
    ax.scatter(
        Cap_real, Cap_predict,
        marker=rate_marker, alpha=1, color=rate_color,
        linewidth=0.0, s=12, edgecolors=None, label=rate_label,
    )
ax.set_xlabel('Real RPT capacity [Ah]', labelpad=1)
ax.set_ylabel('Measurement [Ah]', labelpad=1)
ax.tick_params(axis='both', which='both', length=0)
# No identity line and no legend are drawn on this figure.
plt.show()



# Empirical cumulative error distributions, one colour per RPT rate:
# solid line = channel-1 (capacity) error, dashed line = channel-0 (voltage) error.
colors = ["#F1766D", "#f3a17c", "#ecc68c", "#b3c6bb", "#78A040", "#839DD1"] #
plt.figure(num=None,figsize=(6.5/2.54,5/2.54),dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top='on', right='on', which='both')
# Loop-invariant, so the tick marks are switched off once for the whole axes.
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
for i in range(len(column_labels)):
    # Convert each rate's curves once instead of once per channel.
    true_curves = np.array(true_ocv_all_diff_rate[i])
    predicted_curves = np.array(predict_ocv_all_diff_rate[i])
    # Point-wise absolute errors over every sample and every curve point.
    error_q = np.abs(predicted_curves[:,:,1].reshape(-1)*nominal_capacity - true_curves[:,:,1].reshape(-1)*nominal_capacity)
    error_v = np.abs(predicted_curves[:,:,0].reshape(-1)*4.2 - true_curves[:,:,0].reshape(-1)*4.2)
    sorted_error_q = np.sort(error_q)
    cumulative_distribution_q = np.linspace(0, 1, len(sorted_error_q))
    sorted_error_v = np.sort(error_v)
    cumulative_distribution_v = np.linspace(0, 1, len(sorted_error_v))
    plt.plot(sorted_error_q, cumulative_distribution_q,color=colors[i], linewidth=2)
    plt.plot(sorted_error_v, cumulative_distribution_v, '--', color=colors[i], linewidth=2)
plt.axhline(y=0.95, color='grey', linestyle=':', linewidth=2, label='95%')
# Empty proxy lines so the legend explains the line styles rather than the rate colours.
plt.plot([], [], color='grey', linestyle='-', linewidth=2, label='Q error')
plt.plot([], [], color='grey', linestyle='--', linewidth=2, label='V error')
plt.xlabel('Error')
plt.ylabel('Cumulative distribution')
plt.legend(frameon=False, labelspacing=0.01, loc='lower right')
plt.show()


# Parity plot of predicted vs. reference capacity per RPT rate, drawn without a legend.
fig, ax = plt.subplots(figsize=(6.5/2.54, 5/2.54), dpi=600)
plt.tick_params(bottom=False, left=False)
for i in range(len(column_labels)):
    # Channel 1 at the last point of every curve, rescaled by nominal_capacity.
    Cap_real = np.array(true_ocv_all_diff_rate[i])[:,-1,1]*nominal_capacity
    Cap_predict = np.array(predict_ocv_all_diff_rate[i])[:,-1,1]*nominal_capacity
    plt.scatter(Cap_real, Cap_predict, marker=scatter_types[i] , alpha=1, color=colors[i],  linewidth=0.0, s=12, edgecolors=None, label=column_labels[i])
# y = x reference line.
plt.plot(np.linspace(2.8,5), np.linspace(2.8,5),'-',color='dimgray')
plt.xlabel('Real RPT capacity [Ah]',labelpad=1)
plt.ylabel('Predictions [Ah]',labelpad=1)
ax.tick_params(axis='both', which='both', length=0)
# Show the figure explicitly; without this it would only be rendered by a later plt.show()
# call belonging to a different figure.
plt.show()


# One heat map per RPT rate of the matrices stored in shap_all_Q_diff_rate
# (rows labelled with feature_names, columns labelled with x_labels).
font_properties = {'family': 'Times New Roman', 'size': 12}
rcParams.update({'font.family': font_properties['family'], 'font.size': font_properties['size']})
cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap', ['#8da0cd', '#87d5c6'])
feature_names = ["Cp", "Cn", "x0", "y0"]
x_labels = ["1", "2", "3", "4", "5"]
for i in range(len(column_labels)):
    fig, ax = plt.subplots(figsize=(6.5/2.54, 4/2.54), dpi=600)
    ax.tick_params(bottom=False, left=False)

    plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
    ax.tick_params(top='on', right='on', which='both')
    ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    # NOTE(review): `cmap2` is not assigned in this cell, while the `cmap` built above is not
    # used here. This relies on `cmap2` already existing when the cell runs - confirm which
    # colormap was intended.
    shap_image = ax.imshow(shap_all_Q_diff_rate[i], aspect='auto', cmap=cmap2, interpolation='nearest')
    ax.set_yticks(np.arange(len(feature_names)))
    ax.set_yticklabels(feature_names)
    ax.set_xticks(np.arange(len(x_labels)))
    ax.set_xticklabels(x_labels)
    fig.colorbar(shap_image, ax=ax)
    ax.set_xlabel('Test round')
    fig.tight_layout()
    plt.show()

#%%
# Pick, for a handful of target end-of-curve values, the sample of one RPT rate whose
# prediction is closest to its reference.
c_rate_idx = 0
# Channel 1 at the last curve point, kept in normalised units (no nominal_capacity scaling).
c_rate_soh = np.array(true_ocv_all_diff_rate[c_rate_idx])[:, -1, 1]
c_rate_soh_pred = np.array(predict_ocv_all_diff_rate[c_rate_idx])[:, -1, 1]
c_rate_soh_diff = np.abs(c_rate_soh_pred - c_rate_soh)

targets = [1, 0.92, 0.84, 0.77]
closest_indices = []
for target in targets:
    # Samples whose reference value lies within +/-0.01 of the target.
    candidates = np.flatnonzero(np.abs(c_rate_soh - target) <= 0.01)
    if candidates.size == 0:
        # No sample near this target: it is skipped, so the result may be shorter than `targets`.
        continue
    # Among the candidates keep the one with the smallest prediction error.
    best_candidate = np.argmin(c_rate_soh_diff[candidates])
    closest_indices.append(candidates[best_candidate])


# colors = ["#F1766D", "#f3a17c", "#ecc68c", "#b3c6bb"]
# Measured (solid) and predicted (dashed) curves of the selected samples:
# channel 0 scaled by 4.2 on the voltage axis, channel 1 scaled by nominal_capacity on the Q axis.
colors = ["#F1766D", "#ecc68c", "#b3c6bb", "#839DD1"]
fig, ax = plt.subplots(figsize=(7/2.54, 5/2.54), dpi=600)
ax.tick_params(bottom=False, left=False)

measured_curves = np.array(true_ocv_all_diff_rate[c_rate_idx])
predicted_curves = np.array(predict_ocv_all_diff_rate[c_rate_idx])

# All measured curves are drawn first, then all predicted ones, reusing the colour per sample.
for idx, i in enumerate(closest_indices):
    ax.plot(
        measured_curves[i, :, 1] * nominal_capacity,
        measured_curves[i, :, 0] * 4.2,
        alpha=1, color=colors[idx], label=f"{c_rate_soh[i]:.2f} SOH measure"
    )

for idx, i in enumerate(closest_indices):
    ax.plot(
        predicted_curves[i, :, 1] * nominal_capacity,
        predicted_curves[i, :, 0] * 4.2,
        '--', alpha=1, color=colors[idx], label=f"{c_rate_soh[i]:.2f} SOH predict"
    )

ax.set_xlabel('Q [Ah]', labelpad=1)
ax.set_ylabel('Voltage [V]', labelpad=1)
ax.tick_params(axis='both', which='both', length=0)
# The line labels are set, but no legend is drawn on this figure.
plt.show()



# dV/dQ of the selected samples: measured (solid) vs. predicted (dashed), obtained by
# differentiating the scaled channel-0 values with respect to the scaled channel-1 values.
fig, ax = plt.subplots(figsize=(7/2.54, 5/2.54), dpi=600)
ax.tick_params(bottom=False, left=False)

measured_curves = np.array(true_ocv_all_diff_rate[c_rate_idx])
predicted_curves = np.array(predict_ocv_all_diff_rate[c_rate_idx])

for idx, i in enumerate(closest_indices):
    measured_q = measured_curves[i, :, 1] * nominal_capacity
    dv_dq_orig = gradient(measured_curves[i, :, 0] * 4.2, measured_q)
    ax.plot(
        measured_q,
        dv_dq_orig,
        alpha=1, color=colors[idx], label=f"{c_rate_soh[i]:.2f} SOH measure"
    )

for idx, i in enumerate(closest_indices):
    predicted_q = predicted_curves[i, :, 1] * nominal_capacity
    dv_dq_predict = gradient(predicted_curves[i, :, 0] * 4.2, predicted_q)
    ax.plot(
        predicted_q,
        dv_dq_predict,
        '--', alpha=1, color=colors[idx], label=f"{c_rate_soh[i]:.2f} SOH predict"
    )

ax.set_xlabel('Q [Ah]', labelpad=1)
ax.set_ylabel('dV/dQ [V/Ah]', labelpad=1)
ax.tick_params(axis='both', which='both', length=0)
# Legend is placed outside the axes, to the right.
ax.legend(frameon=False, labelspacing=0.1, loc='center left', bbox_to_anchor=(1, 0.5))
ax.set_ylim([-1, -0.00])
plt.show()


#%%
# Long-format table with one row per test sample: RPT rate label, predicted capacity
# and the four per-sample parameters (Cp, Cn, x0, y0).
colors = ["#F1766D", "#f3a17c", "#ecc68c", "#b3c6bb", "#78A040", "#839DD1"]
font_properties = {'family': 'Times New Roman', 'size': 12}
rcParams.update({'font.family': font_properties['family'], 'font.size': font_properties['size']})
categories = ["C/80", "C/40", "0.05 A", "C/10", "C/7", "C/5"]

data = []
for i, cat in enumerate(categories):
    # Channel 1 at the last curve point, rescaled by nominal_capacity.
    Cap_predict = np.array(predict_ocv_all_diff_rate[i])[:, -1, 1] * nominal_capacity
    # Flatten the per-rate parameter arrays so they can be indexed per sample.
    Cp = Cp_all_diff_rate[i].reshape(-1)
    Cn = Cn_all_diff_rate[i].reshape(-1)
    x0 = x0_all_diff_rate[i].reshape(-1)
    y0 = y0_all_diff_rate[i].reshape(-1)
    # NOTE(review): Cp/Cn use the literal 4.84 while the capacity uses nominal_capacity;
    # presumably the same value - confirm.
    data.extend(
        [cat, Cap_predict[j], Cp[j] * 4.84, Cn[j] * 4.84, x0[j], y0[j]]
        for j in range(len(Cap_predict))
    )

df = pd.DataFrame(data, columns=["RPT current", "Capacity [Ah]", "Cp", "Cn", "x0", "y0"])

# Distribution of the predicted capacity per RPT rate: half violin + open box + jittered points.
fig, ax = plt.subplots(figsize=(4 / 2.54, 5 / 2.54), dpi=600)
sns.set_style("white")
ax.grid(False)
ax.grid(axis='y', color="lightgray", linestyle='-', linewidth=0.7)

rate_column = "RPT current"
for i, category in enumerate(categories):
    # Each rate is drawn separately so that every layer can take that rate's colour.
    subset = df.loc[df[rate_column] == category]
    rate_color = colors[i]
    layer_vars = dict(x="Capacity [Ah]", y=rate_column, data=subset)

    # Half violin (split=True) coloured through the hue/palette mapping.
    sns.violinplot(
        **layer_vars,
        bw_adjust=0.3, cut=0, linewidth=0.5, density_norm="width",
        hue=rate_column, palette={category: rate_color},
        inner=None, split=True, legend=False, alpha=0.8,
    )

    # Unfilled box with coloured outline and whiskers; outliers hidden.
    sns.boxplot(
        **layer_vars,
        width=0.2, showcaps=True,
        boxprops={'facecolor': 'none', 'edgecolor': rate_color, 'linewidth': 1},
        whiskerprops={'color': rate_color, 'linewidth': 1.5},
        saturation=1, showfliers=False, zorder=1,
    )

    # Individual samples, jittered, drawn on top.
    sns.stripplot(
        **layer_vars,
        size=2, color=rate_color, jitter=0.15,
        edgecolor="auto", linewidth=0.5, zorder=5, alpha=0.6,
    )

sns.despine(left=True)
ax.set_xlabel("Capacity [Ah]")
plt.show()

# Per-rate distributions of the parameters Cp, Cn, x0 and y0: one figure per parameter,
# each showing a half violin, an open box and the jittered samples for every RPT rate.
# The y tick labels are removed on these figures.
font_properties = {'family': 'Times New Roman', 'size': 12}
rcParams.update({'font.family': font_properties['family'], 'font.size': font_properties['size']})

# (DataFrame column, x-axis label) for each figure, in drawing order.
parameter_panels = [
    ("Cp", "Cp [Ah]"),
    ("Cn", "Cn [Ah]"),
    ("x0", "x0"),
    ("y0", "y0"),
]

for parameter_column, parameter_label in parameter_panels:
    fig, ax = plt.subplots(figsize=(4 / 2.54, 5 / 2.54), dpi=600)
    sns.set_style("white")
    ax.grid(False)
    ax.grid(axis='y', color="lightgray", linestyle='-', linewidth=0.7)
    for i, category in enumerate(categories):
        # Each rate is drawn separately so that every layer can take that rate's colour.
        subset = df[df["RPT current"] == category]

        # `bw_method` / `density_norm` replace the deprecated `bw` / `scale` arguments
        # (seaborn maps bw -> bw_method and scale -> density_norm), and the palette is
        # attached through `hue` with the legend disabled instead of the deprecated
        # palette-without-hue form.
        sns.violinplot(
            x=parameter_column, y="RPT current", data=subset, bw_method=0.3, cut=0, linewidth=0.5,
            density_norm="width", hue="RPT current", palette={category: colors[i]}, inner=None,
            split=True, legend=False, alpha=0.8,
        )
        # Unfilled box with coloured outline and whiskers; outliers hidden.
        sns.boxplot(
            x=parameter_column, y="RPT current", data=subset, width=0.2, showcaps=True,
            boxprops={'facecolor': 'none', 'edgecolor': colors[i], 'linewidth': 1},
            whiskerprops={'color': colors[i], 'linewidth': 1.5},
            saturation=1, showfliers=False, zorder=1
        )
        # Individual samples, jittered, drawn on top.
        sns.stripplot(
            x=parameter_column, y="RPT current", data=subset, size=2, color=colors[i],
            jitter=0.15, edgecolor="gray", linewidth=0.5, zorder=5, alpha=0.6,
        )
    # Remove the y label, tick labels and ticks on this figure.
    ax.set_ylabel("")
    ax.set_yticklabels([])
    ax.yaxis.set_ticks([])
    sns.despine(left=True)
    plt.xlabel(parameter_label)
    plt.show()


#%%
# Scatter of the C/5 rows against the C/40 rows of `df`, one figure per quantity.
font_properties = {'family': 'Times New Roman', 'size': 12}
rcParams.update({'font.family': font_properties['family'], 'font.size': font_properties['size']})

# (DataFrame column, marker colour, x-axis label, y-axis label) for each figure.
comparison_panels = [
    ("Capacity [Ah]", colors[0], "C/5 Capacity [Ah]", "C/40 Capacity [Ah]"),
    ("Cp", colors[1], "C/5 Cp[Ah]", "C/40 Cp [Ah]"),
    ("Cn", colors[2], "C/5 Cn [Ah]", "C/40 Cn [Ah]"),
    ("x0", colors[3], "C/5 x0", "C/40 x0"),
    ("y0", colors[4], "C/5 y0", "C/40 y0"),
]

for value_column, point_color, x_text, y_text in comparison_panels:
    c5_data = df.loc[df["RPT current"] == "C/5", value_column].to_numpy()
    c40_data = df.loc[df["RPT current"] == "C/40", value_column].to_numpy()

    # Truncate both series to a common length; points are paired purely by row position.
    # NOTE(review): assumes the k-th C/5 row and the k-th C/40 row belong together - confirm.
    min_length = min(len(c5_data), len(c40_data))
    c5_data = c5_data[:min_length]
    c40_data = c40_data[:min_length]

    fig, ax = plt.subplots(figsize=(5/2.54, 5/2.54), dpi=600)
    sns.set_style("white")
    ax.scatter(c5_data, c40_data, color=point_color, s=20, alpha=0.7, edgecolor='w', linewidth=0.5)
    ax.set_xlabel(x_text, fontproperties=font_properties)
    ax.set_ylabel(y_text, fontproperties=font_properties)
    fig.tight_layout()
    plt.show()


#%%

# Pool the per-rate results along the sample axis and score the end-of-curve capacity.
predicted_ocv_curve_test = np.concatenate(predict_ocv_all_diff_rate, axis=0)
Calculated_test_ocv = np.concatenate(calculated_ocv_all_diff_rate, axis=0)
true_ocv_curve_test = np.concatenate(true_ocv_all_diff_rate, axis=0)
Cp = np.concatenate(Cp_all_diff_rate, axis=0)
Cn = np.concatenate(Cn_all_diff_rate, axis=0)
x0 = np.concatenate(x0_all_diff_rate, axis=0)
y0 = np.concatenate(y0_all_diff_rate, axis=0)

# Channel 1 at the last curve point, kept in normalised units (no nominal_capacity scaling).
Cap_real = true_ocv_curve_test[:, -1, 1]
Cap_predict = predicted_ocv_curve_test[:, -1, 1]

# Two-colour gradient used by the error heat maps drawn with `cmap`.
cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap', ['#E59693', '#0073B1'])

RMSE_predicted, MAE_predicted = error_evaluation(Cap_predict.reshape(-1), Cap_real.reshape(-1))

# Coefficient of determination: 1 - SS_res / SS_tot.
residuals = Cap_predict - Cap_real
ss_res = np.sum(residuals**2)
ss_tot = np.sum((Cap_real - np.mean(Cap_real))**2)
r_squared = 1 - (ss_res / ss_tot)
print("R^2:", r_squared)


# Empirical cumulative distribution of the absolute and relative end-of-curve capacity error.
error = np.abs(Cap_real - Cap_predict)
error_relative = error / Cap_real
sorted_error = np.sort(error)
cumulative_distribution = np.linspace(0, 1, len(sorted_error))
sorted_error_relative = np.sort(error_relative)
cumulative_distribution_relative = np.linspace(0, 1, len(sorted_error_relative))

plt.figure(num=None, figsize=(6.5/2.54, 5/2.54), dpi=600)
plt.ion()
plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
plt.tick_params(top='on', right='on', which='both')
# 95 % reference level, drawn first so it is the first legend entry.
plt.axhline(y=0.95, color='#daaa89', linestyle='--', linewidth=2, label='95%')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
# Errors are multiplied by 100 so the x axis reads in percent.
for x_values, cdf_values, series_label, series_color in (
    (100*sorted_error, cumulative_distribution, 'Absolute error', '#5ca7d1'),
    (100*sorted_error_relative, cumulative_distribution_relative, 'Relative error', '#f1b2c2'),
):
    plt.plot(x_values, cdf_values, label=series_label, color=series_color, linewidth=2)
plt.xlabel('SOH error [%]')
plt.ylabel('Cumulative distribution')
plt.legend(frameon=False, labelspacing=0.01, loc='lower right')
plt.show()

# Empirical cumulative distribution of the point-wise errors over all pooled samples:
# channel 1 scaled by nominal_capacity ("Q error") and channel 0 scaled by 4.2 ("V error").
error_q = np.abs(
    predicted_ocv_curve_test[:, :, 1].reshape(-1)*nominal_capacity
    - true_ocv_curve_test[:, :, 1].reshape(-1)*nominal_capacity
)
error_v = np.abs(
    predicted_ocv_curve_test[:, :, 0].reshape(-1)*4.2
    - true_ocv_curve_test[:, :, 0].reshape(-1)*4.2
)
sorted_error_q = np.sort(error_q)
cumulative_distribution_q = np.linspace(0, 1, len(sorted_error_q))
sorted_error_v = np.sort(error_v)
cumulative_distribution_v = np.linspace(0, 1, len(sorted_error_v))

plt.figure(num=None, figsize=(6.5/2.54, 5/2.54), dpi=600)
plt.ion()
plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
plt.tick_params(top='on', right='on', which='both')
# 95 % reference level, drawn first so it is the first legend entry.
plt.axhline(y=0.95, color='#daaa89', linestyle='--', linewidth=2, label='95%')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
for x_values, cdf_values, series_label, series_color in (
    (sorted_error_q, cumulative_distribution_q, 'Q error', '#5ca7d1'),
    (sorted_error_v, cumulative_distribution_v, 'V error', '#f1b2c2'),
):
    plt.plot(x_values, cdf_values, label=series_label, color=series_color, linewidth=2)
plt.xlabel('Error')
plt.ylabel('Cumulative distribution')
plt.legend(frameon=False, labelspacing=0.01, loc='lower right')
plt.show()



# Heat map (samples x curve points) of the absolute channel-1 error scaled by nominal_capacity,
# followed by RMSE/MAE and R^2 over all points.
error_matrix = np.abs(
    predicted_ocv_curve_test[:, :, 1]*nominal_capacity
    - true_ocv_curve_test[:, :, 1]*nominal_capacity
)
plt.figure(num=None, figsize=(7/2.54, 5/2.54), dpi=600)
plt.ion()
plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
ax = sns.heatmap(error_matrix, cmap=cmap, xticklabels=8)

# Replace the default column ticks with 10 evenly spaced ones, labelled with the
# channel-0 values (x 4.2) of the first predicted sample.
n_points = len(predicted_ocv_curve_test[0, :, 0])
x_ticks = np.linspace(0, n_points - 1, 10).astype(int)
voltage_values = predicted_ocv_curve_test[0, :, 0] * 4.2
ax.set_xticks(x_ticks)
ax.set_xticklabels([f'{x:.2f}' for x in voltage_values[x_ticks]])
plt.xlabel('Voltage [V]')
plt.ylabel('Test Sample')
plt.show()

# Flattened predicted / reference values used for the scalar metrics.
predicted_q_flat = predicted_ocv_curve_test[:, :, 1].reshape(-1)*nominal_capacity
true_q_flat = true_ocv_curve_test[:, :, 1].reshape(-1)*nominal_capacity
RMSE_predicted, MAE_predicted = error_evaluation(np.array(predicted_q_flat), np.array(true_q_flat))

# Coefficient of determination: 1 - SS_res / SS_tot.
residuals = predicted_q_flat - true_q_flat
ss_res = np.sum(residuals**2)
ss_tot = np.sum((true_q_flat - np.mean(true_q_flat))**2)
r_squared = 1 - (ss_res / ss_tot)
print("R^2:", r_squared)


# Heat map (samples x curve points, first point dropped) of the absolute channel-0 error
# scaled by 4.2, followed by RMSE/MAE and R^2 over the same points.
error_matrix2 = np.abs(4.2*predicted_ocv_curve_test[:, 1:, 0] - 4.2*true_ocv_curve_test[:, 1:, 0])
plt.figure(num=None, figsize=(7/2.54, 5/2.54), dpi=600)
plt.ion()
plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
ax = sns.heatmap(error_matrix2, cmap=cmap, xticklabels=8)

# Replace the default column ticks with 10 evenly spaced ones. Despite its name,
# `voltage_values` holds the channel-1 values (x nominal_capacity) of the first
# reference sample, matching the 'Q [Ah]' axis label.
n_points = len(true_ocv_curve_test[0, 1:, 1])
x_ticks = np.linspace(0, n_points - 1, 10).astype(int)
voltage_values = true_ocv_curve_test[0, 1:, 1] * nominal_capacity
ax.set_xticks(x_ticks)
ax.set_xticklabels([f'{x:.2f}' for x in voltage_values[x_ticks]])
plt.xlabel('Q [Ah]')
# The sample axis carries no ticks or labels on this figure.
ax.set_yticklabels([])
ax.set_yticks([])
plt.show()

# Flattened predicted / reference values used for the scalar metrics.
predicted_v_flat = predicted_ocv_curve_test[:, 1:, 0].reshape(-1)*4.2
true_v_flat = true_ocv_curve_test[:, 1:, 0].reshape(-1)*4.2
RMSE_predicted, MAE_predicted = error_evaluation(predicted_v_flat, true_v_flat)

# Coefficient of determination: 1 - SS_res / SS_tot.
residuals = predicted_v_flat - true_v_flat
ss_res = np.sum(residuals**2)
ss_tot = np.sum((true_v_flat - np.mean(true_v_flat))**2)
r_squared = 1 - (ss_res / ss_tot)
print("R^2:", r_squared)

# Heat map (samples x curve points, first point dropped) of the absolute error between
# Calculated_test_ocv and reference channel 0, both scaled by 4.2, followed by
# RMSE/MAE and R^2 over the same points.
cmap2 = mcolors.LinearSegmentedColormap.from_list('custom_cmap', ['#facaa9', '#558c83'])
error_matrix3 = np.abs(4.2*Calculated_test_ocv[:, 1:] - 4.2*true_ocv_curve_test[:, 1:, 0])
plt.figure(num=None, figsize=(7/2.54, 5/2.54), dpi=600)
plt.ion()
plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
ax = sns.heatmap(error_matrix3, cmap=cmap2, xticklabels=8)

# Replace the default column ticks with 10 evenly spaced ones. Despite its name,
# `voltage_values` holds the channel-1 values (x nominal_capacity) of the first
# reference sample, matching the 'Q [Ah]' axis label.
n_points = len(true_ocv_curve_test[0, 1:, 1])
x_ticks = np.linspace(0, n_points - 1, 10).astype(int)
voltage_values = true_ocv_curve_test[0, 1:, 1] * nominal_capacity
ax.set_xticks(x_ticks)
ax.set_xticklabels([f'{x:.2f}' for x in voltage_values[x_ticks]])
plt.xlabel('Q [Ah]')
# The sample axis carries no ticks or labels on this figure.
ax.set_yticklabels([])
ax.set_yticks([])
plt.show()

# Flattened derived / reference values used for the scalar metrics.
calculated_v_flat = Calculated_test_ocv[:, 1:].reshape(-1)*4.2
true_v_flat = true_ocv_curve_test[:, 1:, 0].reshape(-1)*4.2
RMSE_predicted, MAE_predicted = error_evaluation(np.array(calculated_v_flat), true_v_flat)

# Coefficient of determination: 1 - SS_res / SS_tot.
residuals = calculated_v_flat - true_v_flat
ss_res = np.sum(residuals**2)
ss_tot = np.sum((true_v_flat - np.mean(true_v_flat))**2)
r_squared = 1 - (ss_res / ss_tot)
print("R^2:", r_squared)





