# -*- coding: utf-8 -*-
"""
"""

# from beep import structure
import os
import numpy as np
from numpy import gradient
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
import random
import seaborn as sns
from matplotlib import rcParams
import pandas as pd
from sklearn import metrics
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.colors as mcolors
import time
from pyswarm import pso
from scipy.signal import savgol_filter

def set_random_seed(seed):
    """Seed every random-number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global generator, PyTorch's CPU
    generator and the current CUDA device's generator, then switches cuDNN to
    deterministic kernels so repeated runs give the same results.

    Args:
        seed: Integer seed applied to all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Global Matplotlib font defaults applied to every figure this script draws.
font_properties = {'family': 'Times New Roman', 'size': 12}
rcParams.update({'font.family': font_properties['family'], 'font.size': font_properties['size']})

#%%
def openBioFile(path):
    """Read a tab-separated EC-Lab ASCII export into a DataFrame.

    The second line of the file is expected to end in ``: <count>``, where
    ``count`` is the number of header lines (EC-Lab writes
    ``Nb header lines : N``). The first ``count - 1`` lines are skipped, and
    line ``count`` is parsed as the column-name row.

    Args:
        path: Path to the export, read as ISO-8859-1 text.

    Returns:
        pandas.DataFrame with one column per exported channel.

    Raises:
        ValueError: If the second line does not end in an integer.
    """
    with open(path, 'rt', encoding="ISO-8859-1") as data:
        next(data)  # title line
        headerLines = next(data)
    # Parse everything after the last ':' instead of a fixed [1:4] slice, which
    # truncated header counts of 1000 or more (e.g. " 1002" -> 100).
    nbHeaderLines = int(headerLines.split(":")[-1].strip())
    return pd.read_csv(path, sep="\t", skiprows=nbHeaderLines - 1, encoding="ISO-8859-1")


# Half-cell reference measurements (EC-Lab exports) used to build the electrode
# open-circuit-potential (OCP) lookup tables below.
# NOTE(review): Windows-style relative paths, resolved against the working directory.
ne_data_Co40 = openBioFile(r"..\006_half_cell\ZABC-A-097_Gr3_Co40_20220819_CD7.txt")
pe_data_Co40 = openBioFile(r"..\006_half_cell\ZABC-A-097_NCA8_Co40_20221020_CD8.txt")

# EC-Lab sequence numbers ("Ns") of the steps extracted from each file.
cc_charge = 4
cc_discharge = 7
# NOTE(review): the positive-electrode frame is named *_charge but is selected
# with cc_discharge (Ns == 7); confirm the intended step for each electrode.
pe_charge_Co40 = pe_data_Co40[(pe_data_Co40["Ns"] == cc_discharge)]
ne_charge_Co40 = ne_data_Co40[(ne_data_Co40["Ns"] == cc_charge)]
# Per electrode: capacity as recorded; voltage reversed in sample order and
# smoothed with a Savitzky-Golay filter (251-sample window, first-order fit).
# NOTE(review): voltage is reversed but capacity is not; confirm this pairing is intended.
pe_capacity = pe_charge_Co40["Capacity/mA.h"].values
pe_voltage = pe_charge_Co40["Ecell/V"].values
pe_voltage = pe_voltage[::-1]
pe_voltage_smoothed = savgol_filter(pe_voltage,window_length=251,polyorder=1)
ne_capacity = ne_charge_Co40["Capacity/mA.h"].values
ne_voltage = ne_charge_Co40["Ecell/V"].values
ne_voltage = ne_voltage[::-1]
ne_voltage_smoothed = savgol_filter(ne_voltage,window_length=251,polyorder=1)

# Resample both half-cell curves onto 1000 points spaced evenly in sample index
# (not in capacity), so both lookup tables have the same length.
x_new = np.linspace(0, 1, 1000)
interp_ocv_c = interp1d(np.linspace(0, 1, len(pe_capacity)), pe_capacity, kind='linear')
interp_ocv_v = interp1d(np.linspace(0, 1, len(pe_voltage_smoothed)), pe_voltage_smoothed, kind='linear')
pe_capacity = interp_ocv_c(x_new)
pe_voltage_smoothed = interp_ocv_v(x_new)

interp_ocv_c = interp1d(np.linspace(0, 1, len(ne_capacity)), ne_capacity, kind='linear')
interp_ocv_v = interp1d(np.linspace(0, 1, len(ne_voltage_smoothed)), ne_voltage_smoothed, kind='linear')
ne_capacity = interp_ocv_c(x_new)
ne_voltage_smoothed = interp_ocv_v(x_new)

# OCP lookup tables as float32 tensors. The SOC axis is capacity normalised by
# its final value; the value axis is the smoothed voltage.
# NOTE(review): the interpolation uses torch.searchsorted, which needs an
# ascending SOC axis; confirm the capacity columns are monotone.
# Use the GPU only when one is available. The former hard-coded 'cuda' crashed
# on CPU-only machines, even though the model itself falls back to the CPU.
_ocp_device = 'cuda' if torch.cuda.is_available() else 'cpu'
OCPn_SOC = torch.tensor(ne_capacity / ne_capacity[-1], dtype=torch.float32).to(_ocp_device)
OCPn_V = torch.tensor(ne_voltage_smoothed, dtype=torch.float32).to(_ocp_device)
OCPp_SOC = torch.tensor(pe_capacity / pe_capacity[-1], dtype=torch.float32).to(_ocp_device)
OCPp_V = torch.tensor(pe_voltage_smoothed, dtype=torch.float32).to(_ocp_device)



def torch_interpolate(x, xp, fp):
    """Differentiable piecewise-linear interpolation of ``fp(xp)`` at ``x``.

    This is the torch counterpart of ``np.interp`` for a 1-D lookup table. One
    difference: queries outside ``[xp[0], xp[-1]]`` are linearly extrapolated
    from the first/last segment rather than clamped.

    Args:
        x: Query points, any shape.
        xp: 1-D tensor of sample positions, sorted ascending (required by
            ``torch.searchsorted``). Repeated positions are allowed.
        fp: 1-D tensor of sample values, same length as ``xp``.

    Returns:
        Tensor with the shape of ``x`` holding the interpolated values.
    """
    # Right end of the bracketing segment. Clamping to [1, len - 1] also picks
    # the end segments for out-of-range queries (hence extrapolation).
    indices = torch.searchsorted(xp, x).clamp(1, len(xp) - 1)
    x0 = xp[indices - 1]
    x1 = xp[indices]
    f0 = fp[indices - 1]
    f1 = fp[indices]
    span = x1 - x0
    # A zero-width segment (repeated xp values) used to give 0/0 = NaN, which
    # poisons the loss and its gradients. For such segments, divide by 1 and
    # return the left value f0; the division stays finite, so the gradient is too.
    zero_span = span == 0
    safe_span = torch.where(zero_span, torch.ones_like(span), span)
    weight = torch.where(zero_span, torch.zeros_like(span), (x - x0) / safe_span)
    return f0 + (f1 - f0) * weight

def OCP_p(SOC_p):
    """Positive-electrode open-circuit potential at stoichiometry ``SOC_p``.

    Interpolates the module-level lookup table ``OCPp_SOC -> OCPp_V``, after
    moving ``SOC_p`` onto the table's device.
    """
    table_soc, table_v = OCPp_SOC, OCPp_V
    return torch_interpolate(SOC_p.to(table_soc.device), table_soc, table_v)

def OCP_n(SOC_n):
    """Negative-electrode open-circuit potential at stoichiometry ``SOC_n``.

    Interpolates the module-level lookup table ``OCPn_SOC -> OCPn_V``, after
    moving ``SOC_n`` onto the table's device.
    """
    table_soc, table_v = OCPn_SOC, OCPn_V
    return torch_interpolate(SOC_n.to(table_soc.device), table_soc, table_v)


#%%

def read_dynamic_data(file_name, idx_seed, num_points=200):
    """Build (partial-discharge window, C/40 reference curve) sample pairs for one cell.

    Every selected diagnostic cycle except the last one yields three samples.
    Round ``r`` reseeds all RNGs with ``idx_seed + r``. Each sample pairs:

    * input: a randomly truncated discharge window from the following cycle
      (diagnostic cycle index + 1), resampled to ``num_points`` points with
      features [voltage / 4.2, discharged capacity since the window start,
      EFC / 500];
    * output: the C/40 (-0.025 C) discharge of the diagnostic cycle, resampled
      to ``num_points`` points with features [voltage / 4.2, capacity].

    Args:
        file_name: Path of the per-cell CSV. Columns read: "Volts", "Cyc#",
            "Step", "Test (Sec)", "Normalized Capacity (nominal capacity unit)",
            "Normalized Current (C-rate)".
        idx_seed: Base random seed for this cell.
        num_points: Number of resampled points per curve.

    Returns:
        Tuple ``(inputs, outputs, soc_range_work, start_v, end_v, cell_name)``:

        * ``inputs``: array of shape (num_samples, num_points, 3);
        * ``outputs``: array of shape (num_samples, num_points, 2);
        * ``soc_range_work``: list of window-capacity / reference-capacity ratios;
        * ``start_v`` / ``end_v``: lists of window start/end voltages [V];
        * ``cell_name``: list repeating ``file_name`` once per sample.

    NOTE(review): several failure modes are not guarded:

    * IndexError if a diagnostic cycle has no sample at exactly -0.025 C;
    * IndexError if the end-of-charge scan on the work cycle never matches;
    * ``k`` may be stale or undefined if ``range(i, j)`` is empty.
    """
    inputs = []
    outputs = []
    start_v = []
    end_v = []
    cap_pari_work = []  # NOTE(review): appended to below but never returned
    soc_range_work = []
    cell_name = []
    
    dynamic_data = pd.read_csv(file_name, low_memory=False)
    Capacity_all = dynamic_data['Normalized Capacity (nominal capacity unit)'].values
    Voltage_all = dynamic_data['Volts'].values  # NOTE(review): not used below
    # Cumulative |capacity| over the whole file. Whenever |capacity| falls below
    # the previous sample (the counter restarted), the previous magnitude is
    # added to the running offset.
    EFC_all = np.zeros_like(Capacity_all)  
    cumulative_offset = 0  
    for i in range(len(Capacity_all)):
        # |capacity| > 1000 is treated as a glitch and replaced by the previous
        # sample (at i == 0, index -1 wraps to the last sample).
        # NOTE(review): this writes into the `.values` array, which may be a
        # view of dynamic_data.
        if abs(Capacity_all[i])>1000:
            Capacity_all[i] = Capacity_all[i-1]
            
        if i > 0 and abs(Capacity_all[i]) < abs(Capacity_all[i - 1]):  
            cumulative_offset += abs(Capacity_all[i - 1])  
        EFC_all[i] = abs(Capacity_all[i]) + cumulative_offset  
    # Halved, presumably because a full cycle moves the capacity twice
    # (charge + discharge), giving equivalent full cycles (EFC).
    EFC_all = EFC_all / 2  
    
    # Per-cycle arrays, one list entry per "Cyc#" value (in groupby's sorted order).
    grouped = dynamic_data.groupby("Cyc#")
    Voltage_cycles = [group["Volts"].values for _, group in grouped]
    Capacity_cycles = [group["Normalized Capacity (nominal capacity unit)"].values for _, group in grouped]
    Step_cycles = [group["Step"].values for _, group in grouped]
    Current_cycles = [group["Normalized Current (C-rate)"].values for _, group in grouped]
    Time_cycles = [group["Test (Sec)"].values for _, group in grouped]
    # Assumes dynamic_data keeps its default RangeIndex, so index labels are
    # positions into EFC_all.
    EFC_cycles = [EFC_all[group.index] for _, group in grouped]
    
    # Diagnostic (RPT) cycles: more than 20000 samples spanning more than
    # 200000 s of test time, followed by a cycle with more than 2000 samples.
    selected_cycles = []
    for i in range(len(Time_cycles)-1):
        if len(Time_cycles[i]) > 20000 and (Time_cycles[i][-1] - Time_cycles[i][0]) > 200000 and len(Time_cycles[i+1]) > 2000:
            selected_cycles.append(i)
    
    # print("diagnostic cycles:", selected_cycles)
    
    # plt.figure(num=None,figsize=(8/2.54,6/2.54),dpi=600)
    # plt.ion()
    # plt.rcParams['xtick.direction'] = 'in'
    # plt.rcParams['ytick.direction'] = 'in'
    # plt.tick_params(top='on', right='on', which='both')
    # plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    # plt.plot(EFC_all)
    
    
    # The last selected diagnostic cycle is skipped. Each remaining one yields
    # three samples that differ only in their random input window.
    for rpt_data_cycle in selected_cycles[:-1]:
        for round_idx in range(3):
            
            set_random_seed(idx_seed+round_idx)
            # plt.figure(num=None,figsize=(8/2.54,6/2.54),dpi=600)
            # plt.ion()
            # plt.rcParams['xtick.direction'] = 'in'
            # plt.rcParams['ytick.direction'] = 'in'
            # plt.tick_params(top='on', right='on', which='both')
            # plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
            # plt.plot(np.array(Time_cycles[rpt_data_cycle]),np.array(Voltage_cycles[rpt_data_cycle]))
            # plt.show()
            
            # ---- Output: C/40 discharge curve of the diagnostic cycle ----
            # Find the longest contiguous run of samples at exactly -0.025 C.
            # NOTE(review): exact float comparison; C040_segment[0] below raises
            # IndexError when no sample matches.
            
            I = np.array(Current_cycles[rpt_data_cycle])
            condition = (I == -0.025)
            indices = np.where(condition)[0]
            C040_segment = []
            current_segment = [indices[0]] if len(indices) > 0 else []
            for i in range(1, len(indices)):
                if indices[i] == indices[i - 1] + 1: 
                    current_segment.append(indices[i])
                else:
                    # A gap ends the current run; keep it if it is the longest so far.
                    if len(current_segment) > len(C040_segment):
                        C040_segment = current_segment
                    current_segment = [indices[i]]
            if len(current_segment) > len(C040_segment):
                C040_segment = current_segment
            # Widen the run to every sample of the same "Step".
            condition = (np.array(Step_cycles[rpt_data_cycle]) == np.array(Step_cycles[rpt_data_cycle][C040_segment[0]]))
            C040_segment = np.where(condition)[0]
            
            # i: first sample of that step with negative current and positive capacity.
            for i in range(len(C040_segment)):
                # print(i)
                if np.array(Current_cycles[rpt_data_cycle])[C040_segment[i]]<0 and np.array(Capacity_cycles[rpt_data_cycle])[C040_segment[i]]>0:
                    break
            
            # j: first sample from i with non-negative current at <= 2.801 V.
            # It only serves as the upper bound of the next scan.
            for j in range(i,len(C040_segment)-1):
                if np.array(Current_cycles[rpt_data_cycle])[C040_segment[j]]>=0 and np.array(Voltage_cycles[rpt_data_cycle])[C040_segment[j]]<=2.801:
                    break
                
            # k: first sample at <= 2.801 V after which the voltage rises again.
            for k in range(i,j):
                if np.array(Voltage_cycles[rpt_data_cycle])[C040_segment[k+1]]>np.array(Voltage_cycles[rpt_data_cycle])[C040_segment[k]] and np.array(Voltage_cycles[rpt_data_cycle])[C040_segment[k]]<=2.801:
                    break
            
            
            # Keep samples i..k-1 of the step; voltage is normalised by 4.2 V.
            C040_segment = C040_segment[i:k]
            discharge_capacity = np.array(Capacity_cycles[rpt_data_cycle])[C040_segment]
            discharge_voltage = np.array(Voltage_cycles[rpt_data_cycle])[C040_segment]/4.2
            
            # print(i,k,j,4.2*discharge_voltage[-1],4.2*discharge_voltage[-2],4.2*discharge_voltage[-3] )
            # plt.figure(num=None,figsize=(8/2.54,6/2.54),dpi=600)
            # plt.ion()
            # plt.rcParams['xtick.direction'] = 'in'
            # plt.rcParams['ytick.direction'] = 'in'
            # plt.tick_params(top='on', right='on', which='both')
            # plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
            # plt.plot(np.array(Time_cycles[rpt_data_cycle])[C040_segment],np.array(Voltage_cycles[rpt_data_cycle])[C040_segment])
            # plt.show()
            
                   
            
            
            # Resample the reference curve to num_points points, evenly spaced in sample index.
            x_new = np.linspace(0, 1, num_points)
            interp_ocv_v = interp1d(np.linspace(0, 1, len(discharge_voltage)), discharge_voltage, kind='linear')
            interp_ocv_q = interp1d(np.linspace(0, 1, len(discharge_capacity)), discharge_capacity, kind='linear')
            discharge_voltage_interp = interp_ocv_v(x_new)
            discharge_capacity_interp = interp_ocv_q(x_new)
            discharge_capacity_diff = discharge_capacity_interp[1:] - discharge_capacity_interp[ :-1]  # NOTE(review): not used below
            
            
            # ---- Input: partial discharge from the following work cycle (rpt_cycle + 1) ----
            I = np.array(Current_cycles[rpt_data_cycle+1])
            V = np.array(Voltage_cycles[rpt_data_cycle+1])
            Q = np.array(Capacity_cycles[rpt_data_cycle+1])
            EFC = np.array(EFC_cycles[rpt_data_cycle+1])/500  # EFC scaled by 1/500
            
            # print(i,k,j,4.2*discharge_voltage[-1],4.2*discharge_voltage[-2],4.2*discharge_voltage[-3] )
            # plt.figure(num=None,figsize=(8/2.54,6/2.54),dpi=600)
            # plt.ion()
            # plt.rcParams['xtick.direction'] = 'in'
            # plt.rcParams['ytick.direction'] = 'in'
            # plt.tick_params(top='on', right='on', which='both')
            # plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
            # plt.plot(np.array(Time_cycles[rpt_data_cycle])[C040_segment],np.array(Voltage_cycles[rpt_data_cycle])[C040_segment])
            # plt.show()
            
            # First sample where both voltage and current drop (presumably the end
            # of charge). NOTE(review): V[i+1] raises IndexError if this never happens.
            for i in range(len(V)):
                if V[i+1]<=V[i] and I[i+1]<I[i]:
                    charge_end_idx = i
                    break
            
            # Advance to the first sample whose next 99 voltages are all lower.
            for i in range(charge_end_idx, len(V) ):
                if all(V[i+1:i+100]<V[i]):
                    charge_end_idx = i
                    break
            
            # First index from which the current stays >= 0 for the rest of the cycle.
            for k in range(charge_end_idx, len(V) ):
                if all(I[k:]>=0 ):
                    break
            # print(i,k)
            end_discharge = k
            # Random window: start within the first 30 % and end within the last
            # 30 % of [charge_end_idx, end_discharge].
            start_index = random.randint(charge_end_idx, charge_end_idx+ int(0.3 * (end_discharge- charge_end_idx)))
            end_index = random.randint(charge_end_idx+int(0.7 * (end_discharge- charge_end_idx)), end_discharge)
            
            partial_discharge_voltage = V[start_index:end_index] / 4.2
            
            partial_discharge_capacity = Q[start_index:end_index]
            partial_discharge_capacity = abs(partial_discharge_capacity)
            check_value = 0  # NOTE(review): not used
            cumulative_offset = 0  
            cumulative_discharge = np.zeros_like(partial_discharge_capacity)
            
            # Stitch counter restarts (a sample equal to 0) into a running total.
            # Values > 1000 are glitches replaced by the previous sample; index -1
            # wraps to the last sample at i == 0.
            for i in range(len(partial_discharge_capacity)):
                if abs(partial_discharge_capacity[i])>1000:
                    partial_discharge_capacity[i] = partial_discharge_capacity[i-1]
                    
                if abs(partial_discharge_capacity[i])==0:  
                    cumulative_offset += abs(partial_discharge_capacity[i - 1])  
                cumulative_discharge[i] = abs(partial_discharge_capacity[i]) + cumulative_offset  
            
            # Discharged capacity relative to the start of the window.
            partial_discharge_capacity = cumulative_discharge-cumulative_discharge[0]
            
            EFCs = EFC[start_index:end_index]
            
            # if rpt_data_cycle==selected_cycles[2] and round_idx==0 :
            #     plt.figure(num=None,figsize=(8/2.54,6/2.54),dpi=600)
            #     plt.ion()
            #     plt.rcParams['xtick.direction'] = 'in'
            #     plt.rcParams['ytick.direction'] = 'in'
            #     plt.tick_params(top='on', right='on', which='both')
            #     plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
            #     plt.plot(partial_discharge_capacity, partial_discharge_voltage)
                
            #     plt.show()
            
            # Resample the window features to num_points points by sample index.
            interp_voltage = interp1d(np.linspace(0, 1, len(partial_discharge_voltage)), partial_discharge_voltage, kind='linear')
            interp_capacity = interp1d(np.linspace(0, 1, len(partial_discharge_capacity)), partial_discharge_capacity, kind='linear')
            interp_EFCs = interp1d(np.linspace(0, 1, len(EFCs)), EFCs, kind='linear')

            partial_discharge_voltage_interp = interp_voltage(x_new)
            partial_discharge_capacity_interp = interp_capacity(x_new)
            partial_discharge_EFCs_interp = interp_EFCs(x_new)
            
           
            # start_v.append(charge_voltage[0]*4.2)
            # end_v.append(charge_voltage[end_index-start_index-1]*4.2)
            # (window capacity, reference capacity) pair; collected but not returned.
            cap_pari_work.append(np.stack([max(partial_discharge_capacity), max(discharge_capacity)], axis=-1))
            
            # Input sample (num_points, 3): [V / 4.2, relative capacity, EFC / 500].
            input_data = np.stack([partial_discharge_voltage_interp, partial_discharge_capacity_interp, partial_discharge_EFCs_interp], axis=-1) #
            # Output sample (num_points, 2): [V / 4.2, capacity] of the C/40 discharge.
            output_data = np.stack([discharge_voltage_interp, discharge_capacity_interp], axis=-1)
            
            # plt.plot(discharge_voltage_interp*4.2,discharge_capacity_interp)
            
            # Fraction of the reference capacity covered by the window.
            soc_range_work.append(max(partial_discharge_capacity)/max(discharge_capacity))
            # Window start/end voltage in volts (index end_index-start_index-1 is the last sample).
            start_v.append(partial_discharge_voltage[0]*4.2)
            end_v.append(partial_discharge_voltage[end_index-start_index-1]*4.2)
            # Add the processed sample to the list
            inputs.append(input_data)
            outputs.append(output_data)    
            cell_name.append(file_name)
    # plt.show()        
    inputs = np.array(inputs)  # Shape (num_samples, num_points, num_features)
    outputs = np.array(outputs)  # Shape (num_samples, num_points, 2)
    
    
    return inputs, outputs, soc_range_work, start_v, end_v, cell_name
        
    
# Per-cell CSV exports (e.g. "Publishing_data_raw_data_cell_093.csv"). The cell
# number is the text after the last "_" and before the first ".".
folder_loc = os.path.abspath('..\\dynamicdata')
# Sort the listing: os.listdir order is arbitrary, and each file's position in
# these lists is passed to read_dynamic_data as its random seed. An unsorted
# listing therefore made the sampled windows machine-dependent.
file_list = sorted(os.listdir(folder_loc))
cell_numbers = {f: int(f.split('_')[-1].split('.')[0]) for f in file_list}

# Split by cell number: n % 3 == 0 goes to training, everything else to test.
test_batteries = [f for f in file_list if cell_numbers[f] % 3 in (1, 2)]
train_batteries = [f for f in file_list if cell_numbers[f] % 3 == 0]

# Construct training and test datasets
# Construct training and test datasets.
# Each list collects one entry per cell; they are concatenated further below.
train_inputs, train_outputs = [], []
test_inputs, test_outputs = [], []
train_start_v, train_end_v, train_soc_range_work= [], [], []
test_start_v, test_end_v, test_soc_range_work= [], [], []
test_cell_names =[]
train_cell_names = []
# Process the battery data in the training set.
# The figure overlays, per training cell, the C/40 end capacity (output channel 1,
# last point; labelled SOH) against EFC at the end of each input window
# (feature 2, multiplied back by 500).
plt.figure(num=None,figsize=(8/2.54,6/2.54),dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
# NOTE(review): the second tick_params call switches off the ticks enabled by the first.
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)

for i in range(0,len(train_batteries)): # list position i is also the cell's base random seed
    battery = train_batteries[i]
    # battery = 'Publishing_data_raw_data_cell_093.csv'
    print("Train cell", battery)
    file_name = os.path.join(folder_loc, battery)
    input_data, output_data, soc_range_work_data, start_v_data, end_v_data, cell_name = read_dynamic_data(file_name,i)
    train_inputs.append(input_data)
    train_outputs.append(output_data)
    train_start_v.append(start_v_data)
    train_end_v.append(end_v_data)
    train_soc_range_work.append(soc_range_work_data)
    train_cell_names.append(cell_name)
    plt.plot(input_data[:,-1,2]*500,output_data[:,-1,1], 'grey')
# plt.plot(np.arange(0,1500,100),0.8*np.ones(len(np.arange(0,1500,100))),'--', color='blue')
plt.xlabel('EFCs')
plt.ylabel('SOH ')
plt.show()
plt.show()    # NOTE(review): duplicate call


# Same per-cell extraction and SOH-vs-EFC overview plot, for the test cells.
plt.figure(num=None,figsize=(8/2.54,6/2.54),dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
# Process the battery data in the test set.
# NOTE(review): seeds restart at 0 here, so test cells reuse the training cells' seed values.
for i in range(0,len(test_batteries)):
    battery = test_batteries[i]
    # set_random_seed(i)
    print("Test cell", battery)
    file_name = os.path.join(folder_loc, battery)
    input_data, output_data, soc_range_work_data, start_v_data, end_v_data, cell_name = read_dynamic_data(file_name,i)
    test_inputs.append(input_data)
    test_outputs.append(output_data)
    test_start_v.append(start_v_data)
    test_end_v.append(end_v_data)
    test_soc_range_work.append(soc_range_work_data)
    test_cell_names.append(cell_name)
    plt.plot(input_data[:,-1,2]*500,output_data[:,-1,1], 'grey')

plt.xlabel('EFCs')
plt.ylabel('SOH ')
plt.show()
# Convert to NumPy arrays
# Flatten the per-cell lists into one array per quantity (one row per sample).
train_inputs, train_outputs, test_inputs, test_outputs = [
    np.concatenate(per_cell, axis=0)
    for per_cell in (train_inputs, train_outputs, test_inputs, test_outputs)
]

(train_start_v, train_end_v, train_soc_range_work,
 test_start_v, test_end_v, test_soc_range_work) = [
    np.concatenate(per_cell, axis=0)
    for per_cell in (train_start_v, train_end_v, train_soc_range_work,
                     test_start_v, test_end_v, test_soc_range_work)
]

print(f"Training data shape: {train_inputs.shape}, {train_outputs.shape}")
print(f"Test data shape: {test_inputs.shape}, {test_outputs.shape}")

train_cell_names, test_cell_names = [
    np.concatenate(per_cell, axis=0) for per_cell in (train_cell_names, test_cell_names)
]

#%%
# Use the default CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

class EncoderMLP(nn.Module):
    """MLP mapping a flattened input window to four scalars ``(Cp, Cn, x0, y0)``.

    BatteryModel uses ``Cp``/``Cn`` as divisors of the predicted capacity and
    ``y0``/``x0`` as offsets when computing the electrode stoichiometries. Each
    output head starts with zero weights and bias 0.9, so an untrained encoder
    returns 0.9 for every parameter. A final ReLU keeps all outputs >= 0.

    Args:
        num_feature: Features per point (3 here: voltage, capacity, EFC).
        hidden_size: Width of the two hidden layers.
        output_size: Width of the shared embedding fed to every head.
        num_points: Points per input window. This was previously hard-coded
            to 200; the default keeps that behaviour.
    """

    def __init__(self, num_feature=3, hidden_size=128, output_size=64, num_points=200):
        super().__init__()
        hidden = int(hidden_size)
        # Layer names are part of the saved state_dict; keep them stable.
        self.fc1 = nn.Linear(num_points * num_feature, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, output_size)

        self.cp_out = nn.Linear(output_size, 1)
        self.cn_out = nn.Linear(output_size, 1)
        self.x0_out = nn.Linear(output_size, 1)
        self.y0_out = nn.Linear(output_size, 1)
        # Constant initial guess of 0.9 for every parameter.
        with torch.no_grad():
            for head in (self.cp_out, self.cn_out, self.x0_out, self.y0_out):
                head.weight.fill_(0.0)
                head.bias.fill_(0.9)

    def forward(self, x):
        """Return ``(Cp, Cn, x0, y0)``, each of shape (batch, 1).

        Args:
            x: Tensor whose non-batch dimensions hold num_points * num_feature
               values, e.g. (batch, num_points, num_feature).
        """
        # reshape (not view) also accepts non-contiguous inputs.
        x = x.reshape(x.size(0), -1)

        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))

        Cp = torch.relu(self.cp_out(x))
        Cn = torch.relu(self.cn_out(x))
        x0 = torch.relu(self.x0_out(x))
        y0 = torch.relu(self.y0_out(x))
        return Cp, Cn, x0, y0


class Decoder(nn.Module):
    """Decode the four encoder parameters into a (voltage, capacity) curve.

    Args:
        num_points: Number of points in each generated curve.
    """

    def __init__(self, num_points=200):
        super().__init__()
        self.num_points = num_points
        # Shared trunk: 4 parameters -> 128 -> 256 hidden units.
        self.fc1 = nn.Linear(4, 128)
        self.fc2 = nn.Linear(128, 256)
        # Separate linear heads for the voltage and capacity channels.
        self.fcv = nn.Linear(256, num_points)
        self.fcq = nn.Linear(256, num_points)

    def forward(self, Cp, Cn, x0, y0):
        """Return the decoded curve.

        For (batch, 1) inputs the result has shape (batch, num_points, 2):
        channel 0 is the voltage head, channel 1 the capacity head.
        """
        hidden = torch.cat((Cp, Cn, x0, y0), dim=-1)
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return torch.stack((self.fcv(hidden), self.fcq(hidden)), dim=-1)


class BatteryModel(nn.Module):
    """Encoder/decoder with a half-cell-OCP reconstruction of the decoded OCV.

    The encoder maps an input window to ``(Cp, Cn, x0, y0)`` and the decoder
    turns them into a (voltage, capacity) curve. The decoded capacity is then
    converted into electrode stoichiometries and passed through the half-cell
    OCP tables to obtain a physics-based OCV.

    Args:
        encoder_name: Encoder architecture; only ``'MLP'`` is supported.
        num_points: Points per decoded curve.

    Raises:
        ValueError: If ``encoder_name`` is not ``'MLP'``.
    """

    def __init__(self, encoder_name, num_points=200):
        super().__init__()
        if encoder_name != 'MLP':
            raise ValueError(f"Unknown encoder name: {encoder_name}")
        self.encoder = EncoderMLP()
        self.decoder = Decoder(num_points=num_points)

    def forward(self, x):
        """Run encoder, decoder and the OCP-based OCV reconstruction.

        Returns:
            ``(OCV_Q_curve, Cp, Cn, x0, y0, Calculated_OCV, SOC_p, SOC_n)``.
            ``OCV_Q_curve`` is the decoder output. ``Calculated_OCV`` is
            ``(OCP_p(SOC_p) - OCP_n(SOC_n)) / 4.2``, i.e. on the same 1/4.2 scale
            as the voltage data.
        """
        Cp, Cn, x0, y0 = self.encoder(x)
        OCV_Q_curve = self.decoder(Cp, Cn, x0, y0)

        capacity = OCV_Q_curve[:, :, 1]
        # Stoichiometries from the decoded capacity. NOTE: Cp/Cn come out of a
        # ReLU and can be exactly 0; nothing here guards that division.
        SOC_p = y0 - capacity / Cp
        SOC_n = x0 - capacity / Cn
        Calculated_OCV = (OCP_p(SOC_p) - OCP_n(SOC_n)) / 4.2

        return OCV_Q_curve, Cp, Cn, x0, y0, Calculated_OCV, SOC_p, SOC_n


def regression_loss(predicted_ocv_Q_curve, true_ocv_Q_curve):
    """Sum of per-channel mean-squared errors between two curves.

    Both tensors are indexed as (batch, point, channel). Channel 0 (voltage)
    and channel 1 (capacity) each get their own MSE, and the two are added.
    """
    mse = nn.MSELoss()
    voltage_term = mse(predicted_ocv_Q_curve[:, :, 0], true_ocv_Q_curve[:, :, 0])
    capacity_term = mse(predicted_ocv_Q_curve[:, :, 1], true_ocv_Q_curve[:, :, 1])
    return voltage_term + capacity_term

def physics_loss(Cp, Cn, x0, y0, predicted_ocv_Q_curve, true_ocv_Q_curve, eps=1e-3):
    """Physics-consistency loss plus soft constraint penalties.

    The decoded capacity channel ``Q`` is turned into stoichiometries
    ``SOC_p = y0 - Q / Cp`` and ``SOC_n = x0 - Q / Cn``. The OCV is then rebuilt
    from the half-cell tables as ``(OCP_p(SOC_p) - OCP_n(SOC_n)) / 4.2``.

    Args:
        Cp, Cn, x0, y0: Encoder outputs, broadcastable against the
            (batch, points) capacity channel.
        predicted_ocv_Q_curve: Decoded curve (channel 0 voltage, channel 1 capacity).
        true_ocv_Q_curve: Not used; kept for interface compatibility.
        eps: Not used. NOTE: Cp/Cn are not clamped before dividing.

    Returns:
        ``(total_physics_loss, total_constrain_loss)``.

        * ``total_physics_loss``: MSE between the rebuilt OCV and the decoded
          voltage channel.
        * ``total_constrain_loss``: sum of hinge penalties that keep:
          Cp/Cn in [0.1, 1.1]; both voltages in [2.7/4.2, 1]; capacity >= 0;
          the first capacity point <= 0; SOC_p/SOC_n in [0, 1].
    """
    voltage = predicted_ocv_Q_curve[:, :, 0]
    capacity = predicted_ocv_Q_curve[:, :, 1]

    SOC_p = y0 - capacity / Cp
    SOC_n = x0 - capacity / Cn
    Calculated_OCV = (OCP_p(SOC_p) - OCP_n(SOC_n)) / 4.2

    def band_penalty(values, low, high):
        # Zero inside [low, high], growing linearly outside it.
        return torch.relu(values - high) + torch.relu(low - values)

    v_floor = 2.7 / 4.2  # lower voltage bound on the normalised scale
    range_terms = (
        torch.relu(Calculated_OCV - 1),
        torch.relu(v_floor - Calculated_OCV),
        torch.relu(voltage - 1),
        torch.relu(v_floor - voltage),
        torch.relu(-capacity),
        torch.relu(capacity[:, 0]),
    )
    predict_range_loss = sum(torch.mean(term) for term in range_terms)

    soc_p_constraint = torch.mean(band_penalty(SOC_p, 0, 1.0))
    soc_n_constraint = torch.mean(band_penalty(SOC_n, 0, 1.0))
    cp_constraint = torch.mean(band_penalty(Cp, 0.1, 1.1))
    cn_constraint = torch.mean(band_penalty(Cn, 0.1, 1.1))

    total_physics_loss = nn.MSELoss()(Calculated_OCV, voltage)
    total_constrain_loss = (
        cp_constraint + cn_constraint + predict_range_loss + soc_p_constraint + soc_n_constraint
    )
    return total_physics_loss, total_constrain_loss


def total_loss(predicted_ocv_Q_curve, true_ocv_Q_curve, Cp, Cn, x0, y0, phys_weight, reg_weight, constraint_weight):
    """Combine the regression, physics and constraint losses.

    Returns:
        ``(total, reg_loss, phys_loss, constraint_loss)``, where ``total`` is
        ``reg_weight * reg_loss + phys_weight * phys_loss
        + constraint_weight * constraint_loss``.
    """
    reg_loss = regression_loss(predicted_ocv_Q_curve, true_ocv_Q_curve)
    phys_loss, constraint_loss = physics_loss(
        Cp, Cn, x0, y0, predicted_ocv_Q_curve, true_ocv_Q_curve
    )
    weighted_sum = (
        reg_weight * reg_loss
        + phys_weight * phys_loss
        + constraint_weight * constraint_loss
    )
    return weighted_sum, reg_loss, phys_loss, constraint_loss

def error_evaluation(prediction, real_value):
    """Print and return ``(RMSE, MAE)`` between two 1-D arrays.

    sklearn expects ``(y_true, y_pred)``. Both metrics are symmetric, so the
    argument order here does not change the result.
    """
    mse = metrics.mean_squared_error(prediction, real_value)
    mae = metrics.mean_absolute_error(prediction, real_value)
    rmse = mse ** 0.5
    print("RMSE, MAE", rmse, mae)
    return rmse, mae


#%%

set_random_seed(123)
# Dataset tensors (float32) on the selected device.
train_inputs_tensor = torch.tensor(train_inputs, dtype=torch.float32).to(device)
train_outputs_tensor = torch.tensor(train_outputs, dtype=torch.float32).to(device)  # reference C/40 curves
test_inputs_tensor = torch.tensor(test_inputs, dtype=torch.float32).to(device)
test_outputs_tensor = torch.tensor(test_outputs, dtype=torch.float32).to(device)  # reference C/40 curves

# Shuffle the training samples; the permutation depends on the seed set above.
num_samples = train_inputs_tensor.shape[0]
shuffled_indices = torch.randperm(num_samples)
train_inputs_shuffled = train_inputs_tensor[shuffled_indices]
train_outputs_shuffled = train_outputs_tensor[shuffled_indices]


val_split = 0.2  # fraction of the shuffled training samples held out for validation
num_train_samples = int((1 - val_split) * num_samples)

# First 80 % train, remaining 20 % validate.
# NOTE(review): the split is per sample, so windows from the same cell/cycle can
# land in both subsets.
train_inputs_train = train_inputs_shuffled[:num_train_samples]
train_outputs_train = train_outputs_shuffled[:num_train_samples]
train_inputs_val = train_inputs_shuffled[num_train_samples:]
train_outputs_val = train_outputs_shuffled[num_train_samples:]


#%%
set_random_seed(123)

patience = 2000  # epochs without validation improvement before stopping
no_improvement = 0  # epochs since the last validation improvement

num_points = 200  # points per curve (must match read_dynamic_data)
model = BatteryModel(encoder_name = 'MLP', num_points=num_points).to(device)
# Fine-tune from a pre-trained checkpoint. map_location lets a checkpoint saved
# on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True))


optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.95, patience=100)

best_loss = float('inf')  # best validation loss seen so far
best_model_path = 'best_model_fine_tune_dynamic_dis.pth'
max_epoch = 20000
phys_weight = 1
reg_weight = 1
constraint_weight = 0.1
min_lr = 5e-7  # stop once ReduceLROnPlateau has decayed the LR below this
total_start_time = time.time()

for epoch in range(max_epoch):
    # Full-batch training step.
    model.train()
    predicted_Q_ocv_curve, Cp, Cn, x0, y0, Calculated_Q_ocv, SOC_p, SOC_n = model(train_inputs_train)
    total_loss_value, reg_loss, phys_loss, constraint_loss = total_loss(
        predicted_Q_ocv_curve, train_outputs_train, Cp, Cn, x0, y0, phys_weight, reg_weight, constraint_weight)

    optimizer.zero_grad()
    total_loss_value.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    # Apply the update BEFORE validating. Previously the step ran after the
    # validation pass, so the checkpoint saved below held weights one update
    # newer than the ones whose validation loss was recorded as the best.
    optimizer.step()

    # Validation on the just-updated weights (the ones that may be saved).
    model.eval()
    with torch.no_grad():
        predicted_val_Q_ocv, Cp, Cn, x0, y0, Calculated_val_Q_ocv, SOC_p, SOC_n = model(train_inputs_val)
        val_total_loss_value, val_reg_loss, val_phys_loss, val_constraint_loss = total_loss(
            predicted_val_Q_ocv, train_outputs_val, Cp, Cn, x0, y0, phys_weight, reg_weight, constraint_weight)
    scheduler.step(val_total_loss_value)
    current_lr = optimizer.param_groups[0]["lr"]

    # Early stopping logic with best validation loss tracking.
    if val_total_loss_value.item() < best_loss:
        best_loss = val_total_loss_value.item()
        no_improvement = 0  # Reset counter if improvement
        torch.save(model.state_dict(), best_model_path)  # Save the best model
        print(f'Epoch {epoch}, '
              f'Training Loss: {total_loss_value.item()} (Reg: {reg_loss.item()}, Phys: {phys_loss.item()}, Con: {constraint_loss.item()})\n'
              f'Validation Loss: {val_total_loss_value.item()} (Reg: {val_reg_loss.item()}, Phys: {val_phys_loss.item()}, Con: {val_constraint_loss.item()}), LR: {current_lr:.6f}')
    else:
        no_improvement += 1

    print(f'Epoch {epoch}, '
          f'Training Loss: {total_loss_value.item()} (Reg: {reg_loss.item()}, Phys: {phys_loss.item()})\n'
          f'Validation Loss: {val_total_loss_value.item()} (Reg: {val_reg_loss.item()}, Phys: {val_phys_loss.item()}), LR: {current_lr:.6f}')

    # Stop on patience or on a vanishing learning rate, reporting which one fired.
    if no_improvement >= patience:
        print(f"Early stopping at epoch {epoch}, Validation loss has not improved in {patience} epochs.")
        break
    if current_lr < min_lr:
        print(f"Early stopping at epoch {epoch}, learning rate {current_lr:.2e} fell below {min_lr:.0e}.")
        break

total_end_time = time.time()
print('Time for training:', total_end_time-total_start_time)

#%%
#%
# Reload the best fine-tuned checkpoint and predict on the held-out test cells.
model.load_state_dict(torch.load(best_model_path, weights_only=True))
model.eval()  # inference mode

with torch.no_grad():
    predicted_test_ocv, Cp, Cn, x0, y0,Calculated_test_ocv, SOC_p, SOC_n = model(test_inputs_tensor)
# Move predictions, targets and fitted parameters to NumPy for metrics/plots.
predicted_ocv_curve_test = predicted_test_ocv.cpu().numpy()
Calculated_test_ocv = Calculated_test_ocv.cpu().numpy()
true_ocv_curve_test = test_outputs_tensor.cpu().numpy()
Cp = Cp.cpu().numpy()
Cn = Cn.cpu().numpy()
x0 = x0.cpu().numpy()
y0 = y0.cpu().numpy()
# Capacity (channel 1) at the last point of each curve: measured vs. predicted.
Cap_real = true_ocv_curve_test[:,-1,1] 
Cap_predict = predicted_ocv_curve_test[:,-1,1] 

#%%

# Parity plot: predicted vs. measured end-of-curve capacity, coloured by absolute error.
cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap',['#E59693','#0073B1'])
fig, ax = plt.subplots(figsize=(10/2.54, 6/2.54), dpi=600)
plt.tick_params(bottom=False, left=False)
orig_color_values = np.abs(Cap_real - Cap_predict)
# NOTE(review): `norm`/`color_values` do not drive this plot's colours (the
# scatter autoscales its own colormap); they are kept for any later code that
# reads them.
norm = Normalize(vmin=orig_color_values.min(), vmax=orig_color_values.max()+1)
color_values = norm(orig_color_values)
scatter = plt.scatter(Cap_real, Cap_predict, c=orig_color_values, alpha=1, cmap=cmap, marker='o', linewidth=0.0, s=10, edgecolors=None)
plt.plot(Cap_real, Cap_real, '-', color='dimgray')  # y = x reference line
divider = make_axes_locatable(ax)
cax = ax.inset_axes([0.7, 0.08, 0.05, 0.35])
cbar = plt.colorbar(scatter, cax=cax)
plt.tick_params(bottom=False, left=False)
# Three colorbar ticks spanning the error range. The colorbar axis is in error
# units (the scatter's own min..max scaling), so ticks go at the raw values.
# Placing them at norm(ticks), a rescale with vmax+1, put labels at the wrong heights.
ticks = np.linspace(orig_color_values.min(), orig_color_values.max(), num=3)
tick_labels = ["{:.2f}".format(value) for value in ticks]
cbar.set_ticks(ticks)
cbar.set_ticklabels(tick_labels)
plt.xlabel('Real',labelpad=1)
plt.ylabel('Predictions',labelpad=1)
ax.tick_params(axis='both', which='both', length=0)
plt.show()


# Accuracy of the model's end-of-curve capacity prediction on the test cells.
RMSE_predicted, MAE_predicted = error_evaluation(Cap_predict.reshape(-1), Cap_real.reshape(-1))
residuals = Cap_predict - Cap_real
ss_res = np.square(residuals).sum()
ss_tot = np.square(Cap_real - Cap_real.mean()).sum()
r_squared = 1 - ss_res / ss_tot  # coefficient of determination
print("R^2:", r_squared)


pre_error = Cap_predict - Cap_real

# Absolute capacity error vs. the start/end voltage of each sampled input window.
fig, ax = plt.subplots(figsize=(10/2.54, 6/2.54), dpi=600)
plt.tick_params(bottom=False, left=False)
abs_error = np.abs(Cap_real - Cap_predict)
scatter = plt.scatter(test_start_v, test_end_v, c=abs_error, alpha=0.8, cmap=cmap, marker='o', linewidth=0.0, s=10, edgecolors=None)
plt.xlabel('Start voltage [V]',labelpad=1)
plt.ylabel('End voltage [V]',labelpad=1)
ax.tick_params(axis='both', which='both', length=0)
cbar = plt.colorbar()
cbar.set_label('Absolute error')
plt.tick_params(bottom=False, left=False)
# The colorbar axis is in error units, so ticks go at the raw values
# (placing them at norm(ticks) misplaced the labels).
ticks = np.linspace(abs_error.min(), abs_error.max(), num=3)
tick_labels = ["{:.2f}".format(value) for value in ticks]
cbar.set_ticks(ticks)
cbar.set_ticklabels(tick_labels)
plt.show()


# plt.figure(num=None,figsize=(6/2.54,5/2.54),dpi=600)
# plt.ion()
# plt.rcParams['xtick.direction'] = 'in'
# plt.rcParams['ytick.direction'] = 'in'
# plt.tick_params(top='on', right='on', which='both')
# plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
# plt.plot(Cap_real, test_inputs_tensor.cpu().numpy()[:,-1,1] ,'o')
# # plt.plot(Cap_real, Cap_real)
# plt.xlabel('Real')
# plt.ylabel('Measurement')
# plt.show()
# fixed_voltage_points = np.linspace(2.8, 4.2, num_points)
# Baseline ("prior"): use the capacity discharged within the partial input
# window (input feature 1, last point) directly as the estimate, scored the
# same way as the model.
measured_capacity = test_inputs_tensor.cpu().numpy()[:, -1, 1].reshape(-1)
RMSE_prior, MAE_prior = error_evaluation(measured_capacity, Cap_real.reshape(-1))

# R^2 of the same baseline. It previously used feature 2 (EFC / 500) instead of
# feature 1, so it did not match the RMSE/MAE above.
residuals = measured_capacity - Cap_real
ss_res = np.sum(residuals**2)
ss_tot = np.sum((Cap_real - np.mean(Cap_real))**2)
r_squared = 1 - (ss_res / ss_tot)
print("R^2:", r_squared)



# Empirical cumulative distribution of absolute and relative capacity errors.
error = np.abs(Cap_real - Cap_predict)
error_relative = np.abs(Cap_real - Cap_predict)/(Cap_real)  # NOTE(review): undefined where Cap_real == 0
sorted_error = np.sort(error)
cumulative_distribution = np.linspace(0, 1, len(sorted_error))
sorted_error_relative = np.sort(error_relative)
cumulative_distribution_relative = np.linspace(0, 1, len(sorted_error_relative))

# Point on each curve nearest the 95th percentile; x values are scaled by 100 (percent).
abs_idx = np.abs(cumulative_distribution - 0.95).argmin()
abs_error_x = 100 * sorted_error[abs_idx]
abs_error_y = cumulative_distribution[abs_idx]
rel_idx = np.abs(cumulative_distribution_relative - 0.95).argmin()
rel_error_x = 100 * sorted_error_relative[rel_idx]
rel_error_y = cumulative_distribution_relative[rel_idx]

plt.figure(num=None,figsize=(6.5/2.54,5/2.54),dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top='on', right='on', which='both')
plt.axhline(y=0.95, color='#daaa89', linestyle='--', linewidth=2, label='95%')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
plt.plot(100*sorted_error, cumulative_distribution, label='Absolute error', color='#5ca7d1', linewidth=2)
plt.plot(100*sorted_error_relative, cumulative_distribution_relative, label='Relative error', color='#f1b2c2', linewidth=2)
# Mark the 95th-percentile errors with a dot and a vertical guide line.
plt.scatter([abs_error_x], [abs_error_y], color='#5ca7d1', s=10, zorder=5)
plt.axvline(x=abs_error_x, color='gray', linestyle='--',  alpha=0.5,linewidth=1)
plt.scatter([rel_error_x], [rel_error_y], color='#f1b2c2', s=10, zorder=5)
plt.axvline(x=rel_error_x, color='gray', linestyle='--', alpha=0.5, linewidth=1)
plt.xlabel('SOH error [%]')
plt.ylabel('Cumulative distribution')
legend = plt.legend(frameon=True,labelspacing=0.01,loc='lower right',facecolor='white', edgecolor='white')
legend.get_frame().set_linewidth(0) 
plt.show()


# Heatmap of the absolute error in predicted Q for every test sample (rows) at
# every point of the OCV curve (columns).
error_matrix = abs(predicted_ocv_curve_test[:,:,1]-true_ocv_curve_test[:,:,1])
plt.figure(num=None, figsize=(7/2.54, 5/2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
ax = sns.heatmap(error_matrix, cmap=cmap, xticklabels=8)
# Label 10 evenly spaced columns with the curve voltage, read from the
# reference (true) curve of the first sample rather than from a model
# prediction. Assumes all samples share the same voltage grid (the row-0
# labelling already relied on this).
x_ticks = np.linspace(0, error_matrix.shape[1] - 1, 10).astype(int)
voltage_values = true_ocv_curve_test[0, :, 0] * 4.2
# seaborn draws column j over [j, j + 1] and centres its ticks at j + 0.5;
# integer positions would sit on cell edges, half a cell off the labelled column.
ax.set_xticks(x_ticks + 0.5)
ax.set_xticklabels([f'{x:.2f}' for x in voltage_values[x_ticks]])
plt.xlabel('Voltage [V]')
plt.ylabel('Test Sample')
plt.show()

# Pointwise accuracy of the predicted Q coordinate (channel 1) across all test
# curves, flattened into one vector.
_q_pred_flat = np.array(predicted_ocv_curve_test[:, :, 1].reshape(-1))
_q_true_flat = np.array(true_ocv_curve_test[:, :, 1].reshape(-1))
RMSE_predicted, MAE_predicted = error_evaluation(_q_pred_flat.copy(), _q_true_flat.copy())

# Coefficient of determination: 1 - SS_res / SS_tot.
residuals = _q_pred_flat - _q_true_flat
ss_res = np.sum(residuals**2)
ss_tot = np.sum((_q_true_flat - np.mean(_q_true_flat))**2)
r_squared = 1 - ss_res / ss_tot
print("R^2:", r_squared)


# Heatmap of the absolute voltage error (rescaled by 4.2 to volts) of the
# predicted OCV curve per test sample, skipping the first curve point.
error_matrix2 = abs(4.2*predicted_ocv_curve_test[:,1:,0]-4.2*true_ocv_curve_test[:,1:,0])
plt.figure(num=None, figsize=(7/2.54, 5/2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
ax = sns.heatmap(error_matrix2, cmap=cmap, xticklabels=8)
# Label 10 evenly spaced columns with the normalised Q of the first sample's
# true curve (assumes all samples share the same Q grid).
x_ticks = np.linspace(0, error_matrix2.shape[1] - 1, 10).astype(int)
voltage_values = true_ocv_curve_test[0, 1:, 1]  # holds Q values despite its name
# seaborn centres column j at j + 0.5; put each tick on the column it labels.
ax.set_xticks(x_ticks + 0.5)
ax.set_xticklabels([f'{x:.2f}' for x in voltage_values[x_ticks]])
plt.xlabel('Normalized Q')
ax.set_yticklabels([])
ax.set_yticks([])
plt.show()

# Pointwise accuracy of the predicted voltage (channel 0 rescaled by 4.2),
# ignoring the first point of each curve.
_v_pred_scaled = predicted_ocv_curve_test[:, 1:, 0].reshape(-1) * 4.2
_v_true_scaled = true_ocv_curve_test[:, 1:, 0].reshape(-1) * 4.2
# Independent array copies used for the R^2 computation.
_v_pred_arr = np.array(_v_pred_scaled)
_v_true_arr = np.array(_v_true_scaled)
RMSE_predicted, MAE_predicted = error_evaluation(_v_pred_scaled, _v_true_scaled)

# Coefficient of determination: 1 - SS_res / SS_tot.
residuals = _v_pred_arr - _v_true_arr
ss_res = np.sum(residuals**2)
ss_tot = np.sum((_v_true_arr - np.mean(_v_true_arr))**2)
r_squared = 1 - ss_res / ss_tot
print("R^2:", r_squared)

# Heatmap of the absolute voltage error (rescaled by 4.2 to volts) of the
# derived OCV curve (Calculated_test_ocv) per test sample, skipping the first
# curve point. Uses its own two-colour map.
cmap2 = mcolors.LinearSegmentedColormap.from_list('custom_cmap', ['#facaa9', '#558c83'])
error_matrix3 = abs(4.2*Calculated_test_ocv[:,1:]-4.2*true_ocv_curve_test[:,1:,0])
plt.figure(num=None, figsize=(7/2.54, 5/2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
ax = sns.heatmap(error_matrix3, cmap=cmap2, xticklabels=8)
# Label 10 evenly spaced columns with the normalised Q of the first sample's
# true curve (assumes all samples share the same Q grid).
x_ticks = np.linspace(0, error_matrix3.shape[1] - 1, 10).astype(int)
voltage_values = true_ocv_curve_test[0, 1:, 1]  # holds Q values despite its name
# seaborn centres column j at j + 0.5; put each tick on the column it labels.
ax.set_xticks(x_ticks + 0.5)
ax.set_xticklabels([f'{x:.2f}' for x in voltage_values[x_ticks]])
plt.xlabel('Normalized Q')
ax.set_yticklabels([])
ax.set_yticks([])
plt.show()


# Pointwise accuracy of the derived OCV voltage against the true curve (both
# rescaled by 4.2), ignoring the first point of each curve.
_v_calc_arr = np.array(Calculated_test_ocv[:, 1:].reshape(-1) * 4.2)
_v_ref_scaled = true_ocv_curve_test[:, 1:, 0].reshape(-1) * 4.2
_v_ref_arr = np.array(_v_ref_scaled)
RMSE_predicted, MAE_predicted = error_evaluation(_v_calc_arr.copy(), _v_ref_scaled)

# Coefficient of determination: 1 - SS_res / SS_tot.
residuals = _v_calc_arr - _v_ref_arr
ss_res = np.sum(residuals**2)
ss_tot = np.sum((_v_ref_arr - np.mean(_v_ref_arr))**2)
r_squared = 1 - ss_res / ss_tot
print("R^2:", r_squared)

#%% 
for i in range(200, len(true_ocv_curve_test), 500):
    # Every 500th test sample, starting at index 200.
    print(test_cell_names[i])

    # Curves of sample i: channel 0 is voltage (x4.2 gives volts), channel 1
    # is normalised charge Q.
    _q_true = true_ocv_curve_test[i, :, 1]
    _v_true = true_ocv_curve_test[i, :, 0] * 4.2
    _q_pred = predicted_ocv_curve_test[i, :, 1]
    _v_pred = predicted_ocv_curve_test[i, :, 0] * 4.2
    _v_derived = Calculated_test_ocv[i, :] * 4.2

    # Numerical differential voltage dV/dQ of each curve, using the Q samples
    # as the gradient coordinates.
    dv_dq_orig = gradient(_v_true, _q_true)
    dv_dq_predict = gradient(_v_pred, _q_pred)
    dv_dq_calculate = gradient(_v_derived, _q_pred)

    # Two vertically stacked panels with no gap between them.
    fig, (ax1, ax2) = plt.subplots(
        2, 1,
        figsize=(6.2 / 2.54, 6 / 2.54),
        dpi=600,
        gridspec_kw={'hspace': 0},
    )

    # Upper panel: voltage against Q.
    _meas_q = test_inputs_tensor[i, :, 1].cpu().numpy()
    _meas_v = test_inputs_tensor[i, :, 0].cpu().numpy() * 4.2
    ax1.plot(_meas_q, _meas_v, label='Measure V-Q', color='#84c3b7', linestyle='-')
    ax1.plot(_q_true, _v_true, '-', label='True OCV', color='grey', linewidth=2)
    ax1.plot(_q_pred, _v_pred, '--', color='#e68b81', linewidth=2)
    ax1.plot(_q_pred, _v_derived, '-.', label='Derived OCV', color='#7da6c6', linewidth=2)
    ax1.set_ylabel('Voltage [V]')
    ax1.tick_params(top=True, right=True, which='both', direction='in')
    ax1.spines['bottom'].set_visible(False)
    ax1.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    ax1.set_xticks([])
    ax1.set_xlim([-0.1, 1.1])

    # Lower panel: dV/dQ against Q.
    ax2.plot(_q_true, dv_dq_orig, '-', label='Measured OCV', color='grey', linewidth=2)
    ax2.plot(_q_pred, dv_dq_predict, '--', label='Predicted OCV', color='#e68b81', linewidth=2)
    ax2.plot(_q_pred, dv_dq_calculate, '-.', label='Derived OCV', color='#7da6c6', linewidth=2)
    ax2.set_xlabel('Normalized Q')
    ax2.set_ylabel('dV/dQ [V/1]')
    ax2.tick_params(top=True, right=True, which='both', direction='in')
    ax2.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    ax2.set_xlim([-0.1, 1.1])
    ax2.set_ylim([-3.5, 1.5])
    fig.align_labels()

    plt.show()


selected_cells = []
#%%
# Signed pointwise errors (prediction minus truth): Q, predicted V and derived
# V, with voltages rescaled by 4.2 and the first curve point skipped for V.
error_matrix = predicted_ocv_curve_test[:, :, 1] - true_ocv_curve_test[:, :, 1]
error_matrix2 = 4.2 * predicted_ocv_curve_test[:, 1:, 0] - 4.2 * true_ocv_curve_test[:, 1:, 0]
error_matrix3 = 4.2 * Calculated_test_ocv[:, 1:] - 4.2 * true_ocv_curve_test[:, 1:, 0]

# Overlaid kernel density estimates of the three error distributions.
plt.figure(num=None, figsize=(5 / 2.54, 5 / 2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
for _errors, _colour, _label in (
    (error_matrix2, '#F1766D', 'Predicted V [V]'),
    (error_matrix3, '#b3c6bb', 'Derived V [V]'),
    (error_matrix, '#839DD1', 'Predicted Q'),
):
    ax = sns.kdeplot(_errors.reshape(-1), fill=True, linewidth=1, alpha=0.5, color=_colour, label=_label)
plt.xlabel('Error')
plt.xlim([-0.08, 0.08])
plt.show()


# Mean absolute error of every test sample (one value per row) for each of the
# three error matrices.
mae_error_matrix = np.mean(np.abs(error_matrix), axis=1)
mae_error_matrix2 = np.mean(np.abs(error_matrix2), axis=1)
mae_error_matrix3 = np.mean(np.abs(error_matrix3), axis=1)

# Long-format table for seaborn: one row per (error type, sample).
mae_data = {'Error Type': [], 'Error Value': []}
for _err_label, _err_values in (
    ('Predicted V [V]', mae_error_matrix2),
    ('Derived V [V]', mae_error_matrix3),
    ('Predicted Q', mae_error_matrix),
):
    mae_data['Error Type'] += [_err_label] * _err_values.size
    mae_data['Error Value'] += list(_err_values.flatten())

mae_data_df = pd.DataFrame(mae_data)

# One violin per error type showing the per-sample MAE distribution.
palette = {'Predicted V [V]': '#F1766D', 'Derived V [V]': '#b3c6bb', 'Predicted Q': '#839DD1'}
plt.figure(num=None, figsize=(5 / 2.54, 5 / 2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
ax = sns.violinplot(x='Error Type', y='Error Value', data=mae_data_df, palette=palette)
plt.ylabel('MAE of each test')
plt.xlabel('Variables [V or Q]')
ax.set_xticklabels([])
plt.show()

#%%

# Scale factor applied to normalised Q on the plots of the PSO section
# (1 keeps the axes in normalised units).
nominal_capacity = 1


# PSO optimization function
def pso_objective_function(params):
    """Return the MSE between ``predict_V`` and the OCV implied by *params*.

    At every charge point ``Qt`` of ``predict_Q`` the full-cell OCV is modelled
    as ``OCP_p(y0 - Qt / Cp) - OCP_n(x0 - Qt / Cn)``.

    Args:
        params: sequence ``(Cp, Cn, x0, y0)`` proposed by ``pso``.

    Returns:
        Mean squared difference between ``predict_V`` and the modelled OCV.

    Note:
        Reads the module-level ``predict_Q`` / ``predict_V`` (assigned in the
        PSO driver loop before ``pso`` is called) and the ``OCP_p`` / ``OCP_n``
        callables. Their inputs are moved to ``'cuda'``, so a GPU is required.
    """
    Cp, Cn, x0, y0 = params
    fitted_Voc = np.zeros_like(predict_Q)
    # The OCP models are only evaluated here, never differentiated, and pso()
    # calls this objective repeatedly (per particle, per iteration), so avoid
    # recording an autograd graph for every scalar evaluation.
    with torch.no_grad():
        for j, Qt in enumerate(predict_Q):
            SOC_p = y0 - Qt / Cp
            SOC_n = x0 - Qt / Cn

            SOC_p_tensor = torch.tensor([SOC_p], dtype=torch.float32).to('cuda')
            SOC_n_tensor = torch.tensor([SOC_n], dtype=torch.float32).to('cuda')

            fitted_Voc[j] = OCP_p(SOC_p_tensor).item() - OCP_n(SOC_n_tensor).item()

    return np.mean((predict_V - fitted_Voc) ** 2)  # MSE


# Pick representative test samples for the PSO fit. error_matrix (Q) and
# error_matrix2 (V, in volts) hold *signed* per-point errors here, so take
# absolute values first. Averaging signed errors lets over- and
# under-prediction cancel, and argmin would then select the most negatively
# biased sample instead of the most accurate one.
_row_mae_q = np.mean(np.abs(error_matrix), axis=1)
_row_mae_v = np.mean(np.abs(error_matrix2), axis=1)
# Most accurate sample: smallest sum of Q and V mean absolute errors
# (Q is normalised and V is in volts, so the sum mixes scales).
min_row_index = np.argmin(_row_mae_q + _row_mae_v)
# min_row_index = np.argmin(_row_mae_q)
# Worst sample: largest single-point absolute Q error.
max_row_index = np.argmax(np.max(np.abs(error_matrix), axis=1))
# Typical sample: Q and V mean absolute errors jointly closest to their medians.
median_row_index = np.argmin(np.abs(_row_mae_q - np.median(_row_mae_q))
                             + np.abs(_row_mae_v - np.median(_row_mae_v)))
# Loop through data points for PSO optimization
# Fit the electrode model (Cp, Cn, x0, y0) to one selected test sample with PSO
# and compare it with the per-sample parameters Cp/Cn/x0/y0 computed earlier in
# the script (called "constrained" in the printout below).
for i in range(1):  # 0, len(predicted_ocv_curve_test), 200
    # The loop runs once and immediately replaces its index with
    # min_row_index; substitute median_row_index / max_row_index (or restore
    # the commented-out range) to inspect other samples.
    i = min_row_index
    
    # Figure 1: measured V-Q input, true OCV, predicted OCV and derived OCV.
    plt.figure(num=None,figsize=(10/2.54,6/2.54),dpi=600)
    plt.ion()
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    plt.tick_params(top='on', right='on', which='both')
    plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    plt.plot(test_inputs_tensor[i, :, 1].cpu().numpy() * nominal_capacity, test_inputs_tensor[i, :, 0].cpu().numpy() * 4.2, 
             label='Measure V-Q', color='blue', linestyle='-')
    plt.plot(true_ocv_curve_test[i, :, 1]*nominal_capacity, true_ocv_curve_test[i, :, 0]*4.2, label='True OCV', color='grey', linestyle='-')
    plt.plot(predicted_ocv_curve_test[i, :,1]*nominal_capacity, predicted_ocv_curve_test[i, :,0]*4.2, label='Predicted OCV', color='#84c3b7', linestyle='--')
    plt.plot(predicted_ocv_curve_test[i, :,1]*nominal_capacity, Calculated_test_ocv[i, :]*4.2, label='Derived OCV',  color='#eaaa60', linestyle='-.')
    plt.xlabel('Normalized Q')
    plt.ylabel('Voltage [V]')
    # plt.title(f'Result on Testing Data (OCV Curve) - Sample {i+1}')
    plt.legend(frameon=False,labelspacing=0.01)
    # plt.grid(True)
    plt.show()
    
    
    # Module-level inputs read by pso_objective_function: the predicted OCV
    # curve of this sample (Q and voltage in volts).
    predict_Q = predicted_ocv_curve_test[i, :, 1]
    predict_V = predicted_ocv_curve_test[i, :, 0] * 4.2
    
    # Parameters already available for this sample from earlier in the script
    # (presumably the model's estimates — TODO confirm where Cp/Cn/x0/y0 are set).
    Cp_con = Cp[i].item()
    Cn_con = Cn[i].item()
    x0_con = x0[i].item()
    y0_con = y0[i].item()
    
    # Search box for (Cp, Cn, x0, y0); Cp, Cn >= 0.1 keeps Qt / Cp and
    # Qt / Cn finite inside the objective.
    lb = [0.1, 0.1, 1e-3, 1e-3]  # Lower bounds
    ub = [1.1, 1.1, 1.0, 1.0]  # Upper bounds
    # Set PSO options, including max iterations to limit the run time
    options = {'maxiter': 20}
    optimized_params, fopt = pso(pso_objective_function, lb, ub, minstep=1e-5, minfunc=1e-5, **options)
    
    print(f"Constrianed parameters: Cp={Cp_con}, Cn={Cn_con}, x0={x0_con}, y0={y0_con}")
    print(f"PSO optimized parameters: Cp={optimized_params[0]}, Cn={optimized_params[1]}, x0={optimized_params[2]}, y0={optimized_params[3]}")
    
    # Rebuild the electrode OCPs and the full-cell OCV for both parameter sets
    # with the same mapping as pso_objective_function:
    # SOC_p = y0 - Q / Cp, SOC_n = x0 - Q / Cn, OCV = OCP_p(SOC_p) - OCP_n(SOC_n).
    Cp_opt, Cn_opt, x0_opt, y0_opt = optimized_params
    fitted_Voc = np.zeros_like(predict_Q)
    OCPp_curve_fit = np.zeros_like(predict_Q)
    OCPn_curve_fit = np.zeros_like(predict_Q)
    OCPp_curve_con = np.zeros_like(predict_Q)
    OCPn_curve_con = np.zeros_like(predict_Q)
    Con_Voc = np.zeros_like(predict_Q)
    for j, Qt in enumerate(predict_Q):
        # PSO-fitted parameters.
        SOC_p_fit = y0_opt - Qt / Cp_opt
        SOC_n_fit = x0_opt - Qt / Cn_opt
        
        SOC_p_tensor_fit = torch.tensor([SOC_p_fit], dtype=torch.float32).to('cuda')
        SOC_n_tensor_fit = torch.tensor([SOC_n_fit], dtype=torch.float32).to('cuda')
        
        Up_fit = OCP_p(SOC_p_tensor_fit).item()
        Un_fit = OCP_n(SOC_n_tensor_fit).item()
        
        OCPp_curve_fit[j] = Up_fit
        OCPn_curve_fit[j] = Un_fit
        fitted_Voc[j] = Up_fit - Un_fit
        
        
        # Parameters taken from Cp/Cn/x0/y0 for this sample.
        SOC_p_con = y0_con - Qt / Cp_con
        SOC_n_con = x0_con - Qt / Cn_con
        
        SOC_p_tensor_con = torch.tensor([SOC_p_con], dtype=torch.float32).to('cuda')
        SOC_n_tensor_con = torch.tensor([SOC_n_con], dtype=torch.float32).to('cuda')
        
        Up_con = OCP_p(SOC_p_tensor_con).item()
        Un_con = OCP_n(SOC_n_tensor_con).item()
        
        OCPp_curve_con[j] = Up_con
        OCPn_curve_con[j] = Un_con
        Con_Voc[j] = Up_con - Un_con
    
    # Plot results
    # Figure 2: measured/predicted/derived OCV plus the electrode OCPs and
    # full-cell OCV from both parameter sets.
    plt.figure(num=None, figsize=(8 / 2.54, 6 / 2.54), dpi=600)
    plt.ion()
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    plt.tick_params(top='on', right='on', which='both')
    plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    plt.plot(true_ocv_curve_test[i, :, 1] * nominal_capacity, true_ocv_curve_test[i, :, 0] * 4.2, '-', label='Measured OCV', color='grey',linewidth=2)
    plt.plot(predict_Q * nominal_capacity, predict_V, '--', label='Predicted OCV', color='#84c3b7',linewidth=2)
    plt.plot(predict_Q * nominal_capacity, Calculated_test_ocv[i, :]*4.2, '-.', label='Derived OCV', color='#7da6c6',linewidth=2)
    plt.plot(predict_Q * nominal_capacity, OCPp_curve_con, '-.', label='Derived OCPp', color='#eaaa60',linewidth=2)
    plt.plot(predict_Q * nominal_capacity, OCPn_curve_con, '-.', label='Derived OCPn', color='#b7b2d0',linewidth=2)
    plt.plot(predict_Q * nominal_capacity, fitted_Voc, ':', label='Fitted OCV', color='#e68b81',linewidth=2)
    plt.plot(predict_Q * nominal_capacity, OCPp_curve_fit, ':', label='Fitted OCPp',color='#78A040',linewidth=2)
    plt.plot(predict_Q * nominal_capacity, OCPn_curve_fit, ':', label='Fitted OCPn',color='#ED98CC',linewidth=2)
    plt.xlabel('Normalized Q')
    plt.ylabel('Voltage [V]')
    # plt.title(f'Result on Testing Data (OCV Curve) - Sample {i+1}')
    plt.legend(frameon=False, labelspacing=0.1, loc='center left',  bbox_to_anchor=(1, 0.5))
    # plt.grid(True)
    plt.show()
    
    
    # Numerical dV/dQ of every curve above (np.gradient with Q as coordinates).
    dv_dq_orig = gradient( true_ocv_curve_test[i, :, 0] * 4.2, true_ocv_curve_test[i, :, 1] * nominal_capacity)
    dv_dq_predict = gradient(predict_V, predict_Q* nominal_capacity)
    dv_dq_fit = gradient(fitted_Voc, predict_Q* nominal_capacity)
    dv_dq_calculate = gradient(Calculated_test_ocv[i, :]*4.2, predict_Q* nominal_capacity)
    dv_dq_cathode_fit = gradient(OCPp_curve_fit, predict_Q* nominal_capacity)
    dv_dq_anode_fit = gradient(OCPn_curve_fit, predict_Q* nominal_capacity)
    dv_dq_cathode_con = gradient(OCPp_curve_con, predict_Q* nominal_capacity)
    dv_dq_anode_con = gradient(OCPn_curve_con, predict_Q* nominal_capacity)
    
    # Figure 3: dV/dQ counterparts of Figure 2 (no legend).
    plt.figure(figsize=(10 / 2.54, 6 / 2.54), dpi=600)
    plt.ion()
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    plt.tick_params(top='on', right='on', which='both')
    plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    plt.plot(true_ocv_curve_test[i, :, 1] * nominal_capacity, dv_dq_orig, '-', label='Measured OCV', color='grey',linewidth=2)
    plt.plot(predict_Q * nominal_capacity, dv_dq_predict, '--', label='Predicted OCV', color='#84c3b7',linewidth=2)
    plt.plot(predict_Q * nominal_capacity, dv_dq_calculate, '-.', label='Derived OCV', color='#7da6c6',linewidth=2)
    plt.plot(predict_Q * nominal_capacity, dv_dq_cathode_con, '-.', label='Derived OCPp', color='#eaaa60',linewidth=2)
    plt.plot(predict_Q * nominal_capacity, dv_dq_anode_con, '-.', label='Derived OCPn', color='#b7b2d0',linewidth=2)
    plt.plot(predict_Q * nominal_capacity, dv_dq_fit, ':', label='Fitted OCV', color='#e68b81',linewidth=2)
    plt.plot(predict_Q * nominal_capacity, dv_dq_cathode_fit, ':', label='Fitted OCPp', color='#78A040',linewidth=2)
    plt.plot(predict_Q * nominal_capacity, dv_dq_anode_fit, ':', label='Fitted OCPn', color='#ED98CC',linewidth=2)
    plt.ylabel('dV/dQ [V/1]')
    plt.xlabel('Normalized Q')
    plt.ylim([-3.5,1.5])
    # plt.legend(fontsize=12, labelspacing=0.1, handletextpad=0.3,framealpha=0)
    plt.show()
    

    
    