import gzip
import json
import numpy as np
import os
from numpy import gradient
import torch
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
import random
import seaborn as sns
from matplotlib import rcParams
import pandas as pd
import re
from matplotlib.colors import LinearSegmentedColormap
import matplotlib as mpl
from beep import structure

# Global figure typography: family and size are written into matplotlib's
# rcParams so that every figure created afterwards picks them up.
font_properties = dict(family='Times New Roman', size=12)
rcParams['font.family'] = font_properties['family']
rcParams['font.size'] = font_properties['size']

def set_random_seed(seed):
    """Seed every random number generator used in this file.

    Covers Python's ``random``, NumPy's legacy global generator and PyTorch
    (default generator plus the CUDA one), then switches cuDNN to
    deterministic mode with autotuning disabled.

    Args:
        seed: Seed value handed unchanged to each generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

    # Reproducibility over speed: fixed algorithms, no benchmark-driven selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

#%%
# Load the saved JSON file (one top-level key per battery)
with gzip.open('data_all_cells.json.gz', 'rt', encoding='utf-8') as f:
    data_all_cells = json.load(f)

#%%

# Extract all battery names
all_batteries = list(data_all_cells.keys())

# The cell number is the first run of digits enclosed by underscores in the
# battery name; leading zeros are dropped so it can be compared with the
# entries of the split files.
all_battery_num = [re.search(r'_(\d+)_', entry).group(1) for entry in all_batteries]
all_battery_num_stripped = [num.lstrip('0') for num in all_battery_num]

# Position of the FIRST battery carrying each stripped number. This gives the
# same answer as list.index() but in O(1) per lookup instead of a full scan.
battery_index_by_num = {}
for position, num in enumerate(all_battery_num_stripped):
    battery_index_by_num.setdefault(num, position)

with open('train_cells.txt', 'r') as file:
    train_cells_content = file.read()

# One cell number per line; numbers without a matching battery are skipped.
train_cells = train_cells_content.splitlines()
train_batteries_indices = [battery_index_by_num[cell] for cell in train_cells if cell in battery_index_by_num]
train_batteries = [all_batteries[i] for i in train_batteries_indices]

with open('test_cells.txt', 'r') as file:
    test_cells_content = file.read()

test_cells = test_cells_content.splitlines()
test_batteries_indices = [battery_index_by_num[cell] for cell in test_cells if cell in battery_index_by_num]
test_batteries = [all_batteries[i] for i in test_batteries_indices]

#%% Load half cell data and fitting function construction
folder_loc = '..\\006_half_cell'  # not used for the reads below: the CSVs are opened from the working directory

# Anode half-cell table (columns used: 'SOC_linspace', 'Voltage')
file_name = 'anode_SiO_Gr_discharge_Cover5_smoothed_dvdq_JS.csv'
file_path = file_name
OCPn_data = pd.read_csv(file_path)

# Cathode half-cell table (same columns)
file_name = 'cathode_NCA_discharge_Cover5_smoothed_dvdq_JS.csv'
file_path = file_name
OCPp_data = pd.read_csv(file_path)

# Keep the lookup tables on the GPU when one is present, otherwise on the CPU
# so that the script still runs on machines without CUDA.
ocp_device = 'cuda' if torch.cuda.is_available() else 'cpu'

OCPn_SOC = torch.tensor(OCPn_data['SOC_linspace'].values, dtype=torch.float32).to(ocp_device)
OCPn_V = torch.tensor(OCPn_data['Voltage'].values, dtype=torch.float32).to(ocp_device)
OCPp_SOC = torch.tensor(OCPp_data['SOC_linspace'].values, dtype=torch.float32).to(ocp_device)
OCPp_V = torch.tensor(OCPp_data['Voltage'].values, dtype=torch.float32).to(ocp_device)


def torch_interpolate(x, xp, fp):
    """Piecewise-linear interpolation of the samples ``(xp, fp)`` at ``x``.

    Each query is bracketed by the pair of neighbouring sample points found
    with ``torch.searchsorted``. Because the bracket index is clamped to the
    valid range, queries outside ``[xp[0], xp[-1]]`` are extrapolated along
    the first or last segment rather than clipped.

    Args:
        x: Tensor of query positions.
        xp: 1-D tensor of sample positions (assumed sorted ascending, as
            ``torch.searchsorted`` expects).
        fp: Tensor of sample values, indexed like ``xp``.

    Returns:
        Tensor of interpolated values, one per element of ``x``.
    """
    upper = torch.clamp(torch.searchsorted(xp, x), min=1, max=xp.shape[0] - 1)
    lower = upper - 1

    left_x, left_f = xp[lower], fp[lower]
    rise = fp[upper] - left_f
    run = xp[upper] - left_x
    offset = x - left_x
    # Multiply before dividing, exactly as in f0 + (f1 - f0) * (x - x0) / (x1 - x0).
    return left_f + rise * offset / run

def OCP_p(SOC_p):
    """Interpolate the cathode half-cell table at ``SOC_p``.

    The query tensor is first moved to the device holding ``OCPp_SOC`` and
    ``OCPp_V``; the result stays on that device.
    """
    soc_on_table_device = SOC_p.to(OCPp_SOC.device)
    return torch_interpolate(soc_on_table_device, OCPp_SOC, OCPp_V)

def OCP_n(SOC_n):
    """Interpolate the anode half-cell table at ``SOC_n``.

    The query tensor is first moved to the device holding ``OCPn_SOC`` and
    ``OCPn_V``; the result stays on that device.
    """
    soc_on_table_device = SOC_n.to(OCPn_SOC.device)
    return torch_interpolate(soc_on_table_device, OCPn_SOC, OCPn_V)

# Half-cell curves exactly as tabulated in the CSV files (NumPy arrays).
orig_ocpp_curve = OCPp_data['Voltage'].values
orig_ocpn_curve = OCPn_data['Voltage'].values
orig_ocv_curve = orig_ocpp_curve - orig_ocpn_curve  # full cell = cathode - anode
# The anode SOC grid is used as the common x axis for all curves below
# (assumes both tables share the same grid -- TODO confirm).
ori_soc_curve = OCPn_data['SOC_linspace'].values

# Evaluate the torch interpolators on the tabulated grid as a sanity check.
# The query is created on whatever device holds the tables instead of a
# hard-coded 'cuda', so it can never disagree with them.
SOC_input = torch.tensor(ori_soc_curve, dtype=torch.float32).to(OCPn_SOC.device)
fit_ocp_p = OCP_p(SOC_input).cpu().numpy()  # Cathode fit
fit_ocp_n = OCP_n(SOC_input).cpu().numpy()  # Anode fit
fit_ocv_curve = fit_ocp_p - fit_ocp_n  # Full cell OCV

plt.figure(num=None,figsize=(10/2.54,6/2.54),dpi=600)
plt.ion()
# Direction must be set before the first axes is created by tick_params().
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
# Tabulated curves as solid lines ...
for curve, colour, tag in (
    (orig_ocpn_curve, '#839DD1', 'Anode'),
    (orig_ocpp_curve, '#27B2AF', 'Cathode'),
    (orig_ocv_curve, '#752a80', 'Full cell OCV'),
):
    plt.plot(ori_soc_curve, curve, color=colour, linewidth=2, label=tag)
# ... and their interpolated counterparts as dotted lines on top.
for curve, colour, tag in (
    (fit_ocp_n, '#F1766D', 'Fit Anode'),
    (fit_ocp_p, '#E59F05', 'Fit Cathode'),
    (fit_ocv_curve, '#78A040', 'Fit Full cell OCV'),
):
    plt.plot(ori_soc_curve, curve, ':', color=colour, linewidth=2, label=tag)
plt.ylabel('Voltage [V]')
plt.xlabel('SOC')
plt.legend(fontsize=12, labelspacing=0.1, handletextpad=0.3,framealpha=0)
plt.show()

# Resample the full-cell OCV and its SOC axis onto coarser/finer uniform grids.
# Both curves are parameterised by a normalised sample position in [0, 1].
interp_ocv = interp1d(np.linspace(0, 1, len(orig_ocv_curve)), orig_ocv_curve, kind='linear')
(interp_ocv50, interp_ocv100, interp_ocv150,
 interp_ocv200, interp_ocv300) = (
    interp_ocv(np.linspace(0, 1, n_samples)) for n_samples in (50, 100, 150, 200, 300)
)

interp_soc = interp1d(np.linspace(0, 1, len(ori_soc_curve)), ori_soc_curve, kind='linear')
(interp_soc50, interp_soc100, interp_soc150,
 interp_soc200, interp_soc300) = (
    interp_soc(np.linspace(0, 1, n_samples)) for n_samples in (50, 100, 150, 200, 300)
)

# Overlay the tabulated OCV with every resampled version.
plt.figure(num=None,figsize=(10/2.54,6/2.54),dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
plt.plot(ori_soc_curve,orig_ocv_curve, color='#F1766D', linewidth=2, label='Full cell OCV')
for soc_axis, ocv_values, fmt, colour, tag in (
    (interp_soc50, interp_ocv50, '--', '#E59F05', '50 points'),
    (interp_soc100, interp_ocv100, '-.', '#839DD1', '100 points'),
    (interp_soc150, interp_ocv150, ':', '#27B2AF', '150 points'),
    (interp_soc200, interp_ocv200, '--', '#752a80', '200 points'),
    (interp_soc300, interp_ocv300, '-.', '#78A040', '300 points'),
):
    plt.plot(soc_axis, ocv_values, fmt, color=colour, linewidth=2, label=tag)
plt.ylabel('Voltage [V]')
plt.xlabel('SOC')
plt.legend(fontsize=12, labelspacing=0.1, handletextpad=0.3,framealpha=0)
plt.show()

# Numerical derivative dV/dSOC of the tabulated OCV and of each resampled copy.
dv_dq_orig = gradient(orig_ocv_curve, ori_soc_curve)
(dv_dq_50, dv_dq_100, dv_dq_150, dv_dq_200, dv_dq_300) = (
    gradient(ocv_values, soc_axis)
    for ocv_values, soc_axis in (
        (interp_ocv50, interp_soc50),
        (interp_ocv100, interp_soc100),
        (interp_ocv150, interp_soc150),
        (interp_ocv200, interp_soc200),
        (interp_ocv300, interp_soc300),
    )
)

# The same curves are drawn twice: an overview (with legend) and a zoomed
# view with both axes restricted (without legend).
for y_window, x_window, draw_legend in (
    ([0, 40], None, True),
    ([0.2, 1.5], [0.2, 1.01], False),
):
    plt.figure(figsize=(10 / 2.54, 6 / 2.54), dpi=600)
    plt.ion()
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    plt.tick_params(top=True, right=True, which='both')
    plt.plot(ori_soc_curve, dv_dq_orig, color='#F1766D', linewidth=2, label='Full cell')
    for soc_axis, slope, fmt, colour, tag in (
        (interp_soc50, dv_dq_50, '--', '#E59F05', '50 points'),
        (interp_soc100, dv_dq_100, '-.', '#839DD1', '100 points'),
        (interp_soc150, dv_dq_150, ':', '#27B2AF', '150 points'),
        (interp_soc200, dv_dq_200, '--', '#752a80', '200 points'),
        (interp_soc300, dv_dq_300, '-.', '#78A040', '300 points'),
    ):
        plt.plot(soc_axis, slope, fmt, color=colour, linewidth=2, label=tag)
    plt.ylabel('dV/dSOC')
    plt.xlabel('SOC')
    plt.ylim(y_window)
    if x_window is not None:
        plt.xlim(x_window)
    if draw_legend:
        plt.legend(fontsize=12, labelspacing=0.1, handletextpad=0.3,framealpha=0)
    plt.show()

# Inverse derivative dSOC/dV, again for the tabulated and the resampled curves.
dq_dv_orig = gradient(ori_soc_curve, orig_ocv_curve)
(dq_dv_50, dq_dv_100, dq_dv_150, dq_dv_200, dq_dv_300) = (
    gradient(soc_axis, ocv_values)
    for soc_axis, ocv_values in (
        (interp_soc50, interp_ocv50),
        (interp_soc100, interp_ocv100),
        (interp_soc150, interp_ocv150),
        (interp_soc200, interp_ocv200),
        (interp_soc300, interp_ocv300),
    )
)

# First pass: full range with legend. Second pass: zoom window, no legend.
for y_window, x_window, draw_legend in (
    (None, None, True),
    ([2.01, 3.5], [3.98, 4.02], False),
):
    plt.figure(figsize=(10 / 2.54, 6 / 2.54), dpi=600)
    plt.ion()
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    plt.tick_params(top=True, right=True, which='both')
    plt.plot(orig_ocv_curve, dq_dv_orig, color='#F1766D', linewidth=2, label='Full cell')
    for ocv_values, slope, fmt, colour, tag in (
        (interp_ocv50, dq_dv_50, '--', '#E59F05', '50 points'),
        (interp_ocv100, dq_dv_100, '-.', '#839DD1', '100 points'),
        (interp_ocv150, dq_dv_150, ':', '#27B2AF', '150 points'),
        (interp_ocv200, dq_dv_200, '--', '#752a80', '200 points'),
        (interp_ocv300, dq_dv_300, '-.', '#78A040', '300 points'),
    ):
        plt.plot(ocv_values, slope, fmt, color=colour, linewidth=2, label=tag)
    plt.ylabel('dSOC/dV')
    plt.xlabel('Voltage [V]')
    if y_window is not None:
        plt.ylim(y_window)
    if x_window is not None:
        plt.xlim(x_window)
    if draw_legend:
        plt.legend(fontsize=12, labelspacing=0.1, handletextpad=0.3,framealpha=0)
    plt.show()

# del OCPn_data, OCPp_data


#%%
import gzip
import json
import os

nominal_capacity = 4.84  # divisor (times two) that turns capacity throughput into EFC below
def read_gzipped_json(file_path):
    """Load a gzip-compressed, UTF-8 encoded JSON file.

    Reading is best-effort: any problem with the file itself is reported on
    stdout and signalled by a ``None`` return instead of an exception.

    Args:
        file_path: Path of the ``.json.gz`` file.

    Returns:
        The decoded JSON document, or ``None`` if the file is missing or
        unreadable, is not valid gzip (including truncated or corrupted
        streams), is not valid UTF-8, or does not contain valid JSON.
    """
    import zlib  # local import: zlib is not among this file's top-level imports

    try:
        with gzip.open(file_path, 'rt', encoding='utf-8') as f:
            return json.load(f)
    # OSError also covers gzip.BadGzipFile. A truncated archive raises
    # EOFError and a damaged deflate stream raises zlib.error; neither of
    # them is an OSError, so they are listed explicitly.
    except (OSError, EOFError, zlib.error, json.JSONDecodeError, UnicodeDecodeError) as e:
        print(f"Error reading {file_path}: {e}")
        return None

# Folder holding one '.json.gz' file per cell
folder_loc = os.path.join('..', '001_EoL_Data')
# os.listdir() returns entries in arbitrary order; sort them so that the
# positional pick below selects the same file on every machine and run.
file_list = sorted(f for f in os.listdir(folder_loc) if f.endswith('.json.gz'))


file_name = file_list[10]
# Read JSON file contents
file_path = os.path.join(folder_loc, file_name)
data = read_gzipped_json(file_path)
if data is None:
    # read_gzipped_json already printed the cause; stop here with a clear
    # message rather than failing on the first data.get() below.
    raise RuntimeError(f"Could not load example cell file {file_path}")

cell_data = {}

# Step 1: Retrieve the protocol and store it into cell_data
protocol = data.get('protocol', 'Unknown')  # If protocol is missing, default to 'Unknown'
cell_data['protocol'] = protocol

# Step 2: Retrieve diagnostic_starts_at from structuring_parameters
diagnostic_starts_at = data.get('structuring_parameters', {}).get('diagnostic_available', {}).get('diagnostic_starts_at', [])

# Retrieve relevant data from raw_data (parallel per-row sequences)
raw_data = data.get('raw_data', {})
cycle_index = raw_data.get('cycle_index', [])
test_time = raw_data.get('test_time', [])
voltage = raw_data.get('voltage', [])
current = raw_data.get('current', [])
charge_capacity = raw_data.get('charge_capacity', [])
discharge_capacity = raw_data.get('discharge_capacity', [])

# Hand-picked row window shown in the protocol figure.
show_window = slice(185100, 348500)
time_show = test_time[show_window]
voltage_show = voltage[show_window]
current_show = current[show_window]

# Rows inside the window where the voltage is a strict local minimum below
# 2.8 V (presumably the end of each discharge -- verify against the protocol).
# The last ten rows of the window are not examined.
idx_end_dis = np.array([
    i for i in range(1, len(time_show) - 10)
    if voltage_show[i] < 2.8
    and voltage_show[i + 1] > voltage_show[i]
    and voltage_show[i - 1] > voltage_show[i]
])


# Current (top) and voltage (bottom) over the selected window, stacked without
# vertical gap. Segments between chosen voltage minima are re-drawn in colour.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12 / 2.54, 6 / 2.54), dpi=600,
                               gridspec_kw={'hspace': 0})

# (first minimum, last minimum, colour) for every highlighted segment
highlighted_segments = (
    (2, 3, '#78A040'),
    (3, 6, '#E59F05'),
    (6, 7, '#839DD1'),
    (7, 8, '#27B2AF'),
    (8, 9, '#752a80'),
)

for panel, signal in ((ax1, current_show), (ax2, voltage_show)):
    panel.plot(time_show, signal, color='#F1766D')
    for first_min, last_min, colour in highlighted_segments:
        segment = slice(idx_end_dis[first_min], idx_end_dis[last_min])
        panel.plot(time_show[segment], signal[segment], label='Voltage', color=colour)

# Top panel: no x axis at all, so that it visually merges with the lower one.
ax1.set_ylabel('Current [A]')
ax1.tick_params(top=True, right=True, which='both', direction='in')
ax1.spines['bottom'].set_visible(False)
ax1.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
ax1.set_xticks([])

# Bottom panel carries the shared time axis.
ax2.set_xlabel('Time [s]')
ax2.set_ylabel('Voltage [V]')
ax2.tick_params(top=True, right=True, which='both', direction='in')
ax2.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
fig.align_labels()
plt.show()

# Step 3: equivalent full cycles (EFC) for every raw-data row.
# The per-row charge and discharge capacities are added to the throughput
# carried over from all earlier cycles (this treats them as running totals
# within a cycle -- TODO confirm against the data format).
EFC_list = []
total_capacity_throughput = 0  # throughput up to and including the current row
last_cycle_throughput = 0  # throughput carried over from completed cycles

n_rows = len(cycle_index)
for idx, cycle in enumerate(cycle_index):
    current_cycle_throughput = charge_capacity[idx] + discharge_capacity[idx]
    total_capacity_throughput = last_cycle_throughput + current_cycle_throughput

    # One full cycle corresponds to a throughput of twice the nominal capacity.
    EFC = total_capacity_throughput / (2 * nominal_capacity)
    EFC_list.append(EFC)

    # Still inside the same cycle: nothing to carry over yet.
    if idx + 1 < n_rows and cycle_index[idx + 1] == cycle:
        continue
    # Last row of this cycle (or of the data): freeze the running total.
    last_cycle_throughput = total_capacity_throughput

# Define the required cycle type labels; label number j is taken from cycle
# (diagnostic start + j).
rpt_labels = ['reset', 'hppc', '02C', '1C', '2C', 'work']

# Index the raw rows by cycle number once, so that each lookup below is a
# dictionary access instead of another scan over the whole cycle_index list.
rows_by_cycle = {}
for idx, cycle in enumerate(cycle_index):
    rows_by_cycle.setdefault(cycle, []).append(idx)

# Step 4: Loop through diagnostic_starts_at list and adjust test_time
cell_data['rpt'] = []
for start in diagnostic_starts_at:
    # Every label gets all six series, left empty when its cycle has no rows.
    rpt_data = {label: {'test_time': [], 'voltage': [], 'current': [], 'charge_capacity': [], 'discharge_capacity': [], 'EFC': []} for label in rpt_labels}

    for j, label in enumerate(rpt_labels):
        target_cycle = start + j
        matching_indices = rows_by_cycle.get(target_cycle, [])
        if not matching_indices:
            continue

        record = rpt_data[label]
        # Shift test_time so that each extracted cycle starts at zero.
        first_test_time = test_time[matching_indices[0]]
        record['test_time'] = [test_time[idx] - first_test_time for idx in matching_indices]
        record['voltage'] = [voltage[idx] for idx in matching_indices]
        record['current'] = [current[idx] for idx in matching_indices]
        record['charge_capacity'] = [charge_capacity[idx] for idx in matching_indices]
        record['discharge_capacity'] = [discharge_capacity[idx] for idx in matching_indices]
        record['EFC'] = [EFC_list[idx] for idx in matching_indices]  # EFC of each data point

    # Store this diagnostic's records into cell_data
    cell_data['rpt'].append(rpt_data)


#%%
nominal_capacity = 4.84
def capacity_extract(data, idx_seed, num_points=200):
    """Collect the peak discharge capacity and peak EFC of every complete RPT.

    An entry of ``data['rpt']`` contributes one row only if it contains the
    '1C', '2C' and '02C' records and none of their 'discharge_capacity'
    sequences is empty; all other entries are skipped.

    Args:
        data: Per-cell dict with an 'rpt' list. Each item maps a rate label
            to a dict holding at least 'discharge_capacity' and 'EFC'
            sequences.
        idx_seed: Unused; kept so existing calls remain valid.
        num_points: Unused; kept so existing calls remain valid.

    Returns:
        Tuple ``(discharge_capacity, EFCs)`` of arrays with shape
        ``(num_kept_rpts, 3)``. Columns are ordered 1C, 2C, 02C and hold the
        maximum of the respective sequence. When no entry qualifies, both
        arrays have shape ``(0, 3)``.
    """
    rate_labels = ('1C', '2C', '02C')  # fixes the column order of the result
    discharge_capacity = []
    EFCs = []

    for rpt_data in data['rpt']:
        if not all(label in rpt_data for label in rate_labels):
            continue
        capacities = [np.array(rpt_data[label]['discharge_capacity']) for label in rate_labels]
        # A rate without any samples makes the whole RPT unusable.
        if any(series.size == 0 for series in capacities):
            continue
        efcs = [np.array(rpt_data[label]['EFC']) for label in rate_labels]

        discharge_capacity.append(np.array([series.max() for series in capacities]))
        EFCs.append(np.array([series.max() for series in efcs]))

    if not discharge_capacity:
        # Keep the result two-dimensional so column indexing still works.
        return np.empty((0, len(rate_labels))), np.empty((0, len(rate_labels)))

    return np.array(discharge_capacity), np.array(EFCs)

# Per-cell results gathered by the loop that follows.
discharge_capacity_all, EFCs_all = [], []

# Two-colour gradient used to shade each fade curve.
colors = ["#839DD1","#78A040"]
cmap = LinearSegmentedColormap.from_list("custom_cmap", colors)

# One panel per discharge rate, sharing the capacity axis.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(20/2.54,6/2.54), dpi=600, sharey=True)
plt.ion()
plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)

initial_cap = []

# Capacity / EFC of each rate at the point where a threshold is crossed.
# Every name gets its own, independent list.
(cap_eight_percent_02C, EFC_eight_percent_02C,
 cap_eight_percent_1C, EFC_eight_percent_1C,
 cap_eight_percent_2C, EFC_eight_percent_2C) = ([] for _ in range(6))

(cap_sevenfive_percent_02C, EFC_sevenfive_percent_02C,
 cap_sevenfive_percent_1C, EFC_sevenfive_percent_1C,
 cap_sevenfive_percent_2C, EFC_sevenfive_percent_2C) = ([] for _ in range(6))

(cap_seven_percent_02C, EFC_seven_percent_02C,
 cap_seven_percent_1C, EFC_seven_percent_1C,
 cap_seven_percent_2C, EFC_seven_percent_2C) = ([] for _ in range(6))

#%
inter_points = 10000  # samples of the densified fade curves used to locate threshold crossings
fine_grid = np.linspace(0, 1, inter_points)

# (threshold, per-column (capacity list, EFC list)); the column order 1C, 2C,
# 02C is the one returned by capacity_extract.
eol_sinks = (
    (0.8, ((cap_eight_percent_1C, EFC_eight_percent_1C),
           (cap_eight_percent_2C, EFC_eight_percent_2C),
           (cap_eight_percent_02C, EFC_eight_percent_02C))),
    (0.75, ((cap_sevenfive_percent_1C, EFC_sevenfive_percent_1C),
            (cap_sevenfive_percent_2C, EFC_sevenfive_percent_2C),
            (cap_sevenfive_percent_02C, EFC_sevenfive_percent_02C))),
    (0.7, ((cap_seven_percent_1C, EFC_seven_percent_1C),
           (cap_seven_percent_2C, EFC_seven_percent_2C),
           (cap_seven_percent_02C, EFC_seven_percent_02C))),
)

# NOTE(review): the loop covers all_batteries although the message printed
# says "Train cell" -- confirm which set is intended.
for i, battery in enumerate(all_batteries):
    print("Train cell", battery)

    discharge_capacity, EFCs = capacity_extract(data_all_cells[battery], i)
    discharge_capacity_all.append(discharge_capacity)
    EFCs_all.append(EFCs)
    initial_cap.append(discharge_capacity[0,:].reshape(-1,1))

    for k in range(3):
        # Shade the curve by the EFC at the first RPT below 75 % of nominal
        # capacity, or at the last RPT if that level is never reached.
        below = np.flatnonzero(discharge_capacity[:, k] / nominal_capacity < 0.75)
        j = below[0] if below.size else len(discharge_capacity) - 1
        color_i = cmap(np.clip(EFCs[j, k] / 1700, 0, 1))

        axes[k].plot(EFCs[:, k], discharge_capacity[:, k], color=color_i)
        axes[k].set_xlabel('EFCs')

    # Densify every column on a common normalised RPT-position axis.
    coarse_grid = np.linspace(0, 1, len(discharge_capacity))
    discharge_capacity_interp = np.column_stack([
        interp1d(coarse_grid, discharge_capacity[:, col], kind='linear')(fine_grid)
        for col in range(3)
    ])
    EFCs_interp = np.column_stack([
        interp1d(coarse_grid, EFCs[:, col], kind='linear')(fine_grid)
        for col in range(3)
    ])

    # The 02C column (index 2) decides when a threshold is crossed; the values
    # of all three rates at that same sample are recorded. Cells that never
    # cross a threshold contribute nothing for it.
    reference_fade = discharge_capacity_interp[:, 2] / nominal_capacity
    for threshold, sinks in eol_sinks:
        crossed = np.flatnonzero(reference_fade < threshold)
        if crossed.size == 0:
            continue
        m = crossed[0]
        for col, (cap_list, efc_list) in enumerate(sinks):
            cap_list.append(discharge_capacity_interp[m, col])
            efc_list.append(EFCs_interp[m, col])


axes[0].set_ylabel('Capacity [Ah]')
for ax in axes:
    ax.tick_params(axis='both', which='both', bottom=False, top=False,
                   left=False, right=False)
plt.tight_layout()
plt.show()



#%%
# Stack the per-cell (3, 1) columns side by side, then transpose:
# one row per cell, columns ordered 1C, 2C, 02C.
initial_cap = np.hstack(initial_cap).T

# Turn every threshold accumulator from a list into an array.
(cap_eight_percent_02C, EFC_eight_percent_02C,
 cap_eight_percent_1C, EFC_eight_percent_1C,
 cap_eight_percent_2C, EFC_eight_percent_2C) = map(np.array, (
    cap_eight_percent_02C, EFC_eight_percent_02C,
    cap_eight_percent_1C, EFC_eight_percent_1C,
    cap_eight_percent_2C, EFC_eight_percent_2C,
))

(cap_sevenfive_percent_02C, EFC_sevenfive_percent_02C,
 cap_sevenfive_percent_1C, EFC_sevenfive_percent_1C,
 cap_sevenfive_percent_2C, EFC_sevenfive_percent_2C) = map(np.array, (
    cap_sevenfive_percent_02C, EFC_sevenfive_percent_02C,
    cap_sevenfive_percent_1C, EFC_sevenfive_percent_1C,
    cap_sevenfive_percent_2C, EFC_sevenfive_percent_2C,
))

(cap_seven_percent_02C, EFC_seven_percent_02C,
 cap_seven_percent_1C, EFC_seven_percent_1C,
 cap_seven_percent_2C, EFC_seven_percent_2C) = map(np.array, (
    cap_seven_percent_02C, EFC_seven_percent_02C,
    cap_seven_percent_1C, EFC_seven_percent_1C,
    cap_seven_percent_2C, EFC_seven_percent_2C,
))


# Distribution of the first-RPT capacity for each rate: histogram plus KDE.
fig, ax = plt.subplots(figsize=(7/2.54, 6/2.54), dpi=600)
ax.tick_params(bottom=False, left=False)
for column, n_bins, colour, tag in (
    (0, 18, '#27B2AF', '1C'),
    (1, 25, '#752a80', '2C'),
    (2, 10, '#839DD1', 'C/5'),
):
    first_rpt_capacity = initial_cap[:, column]
    plt.hist(first_rpt_capacity, bins=n_bins, fill=True, density=True, linewidth=0.5,
             alpha=0.7, color=colour, edgecolor='black', label=tag)
    sns.kdeplot(first_rpt_capacity, fill=False, linewidth=1, color=colour, label=None)
plt.legend(frameon=False,labelspacing=0.01,loc='upper left')
plt.xlabel('Capacity [Ah]')
plt.ylabel('Density')
plt.show()


# 02C capacity recorded at the 0.8 / 0.75 / 0.7 thresholds, one panel each.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(20/2.54,6/2.54), dpi=600, sharey=True)
plt.ion()
plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)

threshold_capacities = (cap_eight_percent_02C, cap_sevenfive_percent_02C, cap_seven_percent_02C)
for panel, capacities in zip(axes, threshold_capacities):
    panel.hist(capacities, bins=20, fill=True, density=True, linewidth=0.5, alpha=0.7,
               color='#839DD1', edgecolor='black', label='C/5')
    sns.kdeplot(capacities, ax=panel, fill=False, linewidth=1, color='#839DD1', label=None)
    panel.set_xlabel('Capacity [Ah]')
axes[0].set_ylabel('Density')
for ax in axes:
    ax.tick_params(axis='both', which='both', bottom=False, top=False,
                   left=False, right=False)
plt.tight_layout()
plt.show()

# Mean and (population) standard deviation of each distribution.
for capacities in threshold_capacities:
    print(np.mean(capacities), np.sqrt(np.var(capacities)))


#%%

# For each threshold (0.8, 0.75, 0.7) draw two small, axis-free figures:
#   1. the 1C and 2C capacity distributions at the crossing point, and
#   2. the 02C EFC distribution at the crossing point, flipped upside down
#      with its x axis on top.
for eol_cap_1C, eol_cap_2C, eol_efc_02C in (
    (cap_eight_percent_1C, cap_eight_percent_2C, EFC_eight_percent_02C),
    (cap_sevenfive_percent_1C, cap_sevenfive_percent_2C, EFC_sevenfive_percent_02C),
    (cap_seven_percent_1C, cap_seven_percent_2C, EFC_seven_percent_02C),
):
    # --- capacity distributions: only the bottom spine remains visible ---
    fig, ax = plt.subplots(figsize=(4/2.54, 3/2.54), dpi=600)
    for side in ('top', 'right', 'left'):
        ax.spines[side].set_visible(False)
    ax.tick_params(bottom=False, left=False)
    ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)
    for values, n_bins, colour, tag in (
        (eol_cap_1C, 18, '#27B2AF', '1C'),
        (eol_cap_2C, 25, '#752a80', '2C'),
    ):
        plt.hist(values, bins=n_bins, fill=True, density=True, linewidth=0.5,
                 alpha=0.7, color=colour, edgecolor='black', label=tag)
        sns.kdeplot(values, fill=False, linewidth=1, color=colour, label=None)
    plt.show()

    # --- EFC distribution: only the top spine remains visible ---
    fig, ax = plt.subplots(figsize=(4/2.54, 3/2.54), dpi=600)
    ax.spines['top'].set_visible(True)
    for side in ('bottom', 'right', 'left'):
        ax.spines[side].set_visible(False)
    ax.xaxis.set_label_position('top')
    ax.xaxis.tick_top()
    ax.tick_params(bottom=False, left=False, top=False)
    ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)
    # The label reads '2C' although the data are the 02C EFC values; no legend
    # is drawn, so the label is never displayed.
    plt.hist(eol_efc_02C, bins=25, fill=True, density=True, linewidth=0.5,
             alpha=0.7, color='#839DD1', edgecolor='black', label='2C')
    sns.kdeplot(eol_efc_02C, fill=False, linewidth=1, color='#839DD1', label=None)
    ax.invert_yaxis()
    plt.show()

#%%

def extract_and_interpolate(data, idx_seed, num_points=200):
    """
    Build (input, output) sample pairs from the 'rpt' records of one cell.

    Every entry of ``data['rpt']`` that contains both a 'work' and a '02C'
    record is visited three times (the last entry is always skipped). In each
    round the global RNGs are reseeded with ``idx_seed + round_idx`` and a
    random window of the 'work' record is cut out: it starts within the first
    30 % and ends within the last 30 % of the span that precedes the first
    voltage drop. The window and the '02C' discharge curve are then linearly
    resampled to ``num_points`` points.

    Parameters
    ----------
    data : dict
        Cell record; only ``data['rpt']`` is read. Each usable entry provides
        ``['work']`` with 'current', 'voltage', 'charge_capacity', 'EFC' and
        ``['02C']`` with 'discharge_capacity', 'voltage'.
    idx_seed : int
        Base seed; round ``r`` uses ``idx_seed + r``.
    num_points : int, optional
        Number of points every curve is resampled to.

    Returns
    -------
    tuple
        ``(inputs, outputs, start_v, end_v, cap_pari_1C, cap_pari_2C,
        cap_pari_work, soc_range_1C, soc_range_2C, soc_range_work)`` where

        * ``inputs``  - array (num_samples, num_points, 3): voltage / 4.2,
          window-relative charge capacity / 4.84, EFC / 500;
        * ``outputs`` - array (num_samples, num_points, 2): discharge
          voltage / 4.2, discharge capacity / 4.84;
        * ``start_v`` / ``end_v`` - first / last window voltage in volts;
        * ``cap_pari_work`` - pairs (window capacity, discharge capacity),
          both multiplied back by 4.84;
        * ``soc_range_work`` - window capacity divided by discharge capacity;
        * the ``*_1C`` / ``*_2C`` lists are never filled here and are returned
          empty so the return signature stays unchanged.

        When no sample passes the filters, ``inputs`` and ``outputs`` are
        empty 1-D arrays.
    """
    nominal_capacity = 4.84
    inputs = []
    outputs = []
    start_v = []
    end_v = []
    cap_pari_1C = []
    cap_pari_2C = []
    cap_pari_work = []
    soc_range_1C = []
    soc_range_2C = []
    soc_range_work = []

    x_new = np.linspace(0, 1, num_points)

    def _resample(values):
        """Linearly resample `values` onto `num_points` evenly spaced points."""
        return interp1d(np.linspace(0, 1, len(values)), values, kind='linear')(x_new)

    for rpt_data in data['rpt']:
        for round_idx in range(3):
            # Reseed before any filtering so the RNG state left behind by this
            # function does not depend on which records were skipped.
            set_random_seed(idx_seed + round_idx)

            if 'work' not in rpt_data or '02C' not in rpt_data:
                continue
            if rpt_data is data['rpt'][-1]:
                continue

            charge_current = np.array(rpt_data['work']['current'])
            charge_voltage = np.array(rpt_data['work']['voltage'])
            charge_capacity = np.array(rpt_data['work']['charge_capacity'])
            EFCs = np.array(rpt_data['work']['EFC'])
            if charge_current.size < 50:
                continue

            # End of the usable span: the first sample whose voltage is lower
            # than the previous one (or the last sample if it never drops).
            for charge_end_idx in range(2, len(charge_current)):
                if charge_voltage[charge_end_idx] < charge_voltage[charge_end_idx - 1]:
                    break

            # Random window: start in [0, 0.3 * end], stop in [0.7 * end, end].
            start_index = random.randint(0, int(0.3 * charge_end_idx))
            end_index = random.randint(int(0.7 * charge_end_idx), charge_end_idx)

            if charge_current[start_index:end_index].size < 50:
                continue

            charge_voltage = charge_voltage[start_index:end_index] / 4.2
            # Capacity is expressed relative to the first sample of the window.
            charge_capacity = (charge_capacity[start_index:end_index] / nominal_capacity
                               - charge_capacity[start_index] / nominal_capacity)
            EFCs = EFCs[start_index:end_index] / 500

            charge_voltage_interp = _resample(charge_voltage)
            charge_capacity_interp = _resample(charge_capacity)
            charge_EFCs_interp = _resample(EFCs)

            discharge_capacity = np.array(rpt_data['02C']['discharge_capacity'])
            discharge_voltage = np.array(rpt_data['02C']['voltage'])

            # The discharge curve starts at the last sample whose capacity is
            # still <= 0. If there is no such sample the whole record is used;
            # the previous fallback (len(...)) produced an empty curve that
            # made the interpolation below raise.
            zero_capacity_idx = np.where(discharge_capacity <= 0)[0]
            dis_start_idx = zero_capacity_idx[-1] if len(zero_capacity_idx) > 0 else 0

            discharge_capacity = discharge_capacity[dis_start_idx:] / nominal_capacity
            discharge_voltage = discharge_voltage[dis_start_idx:] / 4.2
            if len(discharge_capacity) < 2 or len(discharge_voltage) < 2:
                # Too short to interpolate.
                continue

            discharge_voltage_interp = _resample(discharge_voltage)
            discharge_capacity_interp = _resample(discharge_capacity)
            discharge_capacity_diff = np.diff(discharge_capacity_interp)

            # Reject samples with a capacity jump above 0.05 between resampled
            # points, a normalised discharge capacity below 0.7, or a window
            # that covers less than 0.2 (in the raw capacity unit).
            if (max(discharge_capacity_diff) > 0.05
                    or max(discharge_capacity) < 0.7
                    or max(charge_capacity) * nominal_capacity < 0.2):
                continue

            start_v.append(charge_voltage[0] * 4.2)
            end_v.append(charge_voltage[end_index - start_index - 1] * 4.2)
            cap_pari_work.append(np.stack([max(charge_capacity) * nominal_capacity,
                                           max(discharge_capacity) * nominal_capacity], axis=-1))
            soc_range_work.append(max(charge_capacity) / max(discharge_capacity))

            inputs.append(np.stack([charge_voltage_interp, charge_capacity_interp,
                                    charge_EFCs_interp], axis=-1))
            outputs.append(np.stack([discharge_voltage_interp, discharge_capacity_interp],
                                    axis=-1))

    inputs = np.array(inputs)    # (num_samples, num_points, 3)
    outputs = np.array(outputs)  # (num_samples, num_points, 2)

    return inputs, outputs, start_v, end_v, cap_pari_1C, cap_pari_2C, cap_pari_work, soc_range_1C, soc_range_2C, soc_range_work


# Build the sample set for every cell in all_batteries and, on one shared
# figure, plot each cell's end-of-curve capacity against its EFC count.
inputs, outputs = [], []
start_v_all, end_v_all, cap_pari_1C_all, cap_pari_2C_all, cap_pari_work_all = [], [], [], [], []
soc_range_1C_all, soc_range_2C_all, soc_range_work_all = [], [], []

plt.figure(num=None,figsize=(8/2.54,4/2.54),dpi=600)
colors = ["#839DD1","#78A040"]
cmap = LinearSegmentedColormap.from_list("custom_cmap", colors)

plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)

# NOTE(review): nominal_capacity below must be a module-level name defined
# earlier in the file; the one inside extract_and_interpolate is local to it.
for i, battery in enumerate(all_batteries):
    print("Train cell", battery)
    (input_data, output_data, start_v_data, end_v_data, cap_pari_1C_data,
     cap_pari_2C_data, cap_pari_work_data, soc_range_1C_data, soc_range_2C_data,
     soc_range_work_data) = extract_and_interpolate(data_all_cells[battery], i)

    if len(input_data) == 0:
        # A cell without any usable sample comes back as an empty 1-D array:
        # the 3-D indexing below would raise and the later np.concatenate
        # would fail on mismatched dimensions, so leave the cell out.
        print("No usable samples, skipping", battery)
        continue

    inputs.append(input_data)
    outputs.append(output_data)
    start_v_all.append(start_v_data)
    end_v_all.append(end_v_data)
    cap_pari_1C_all.append(cap_pari_1C_data)
    cap_pari_2C_all.append(cap_pari_2C_data)
    cap_pari_work_all.append(cap_pari_work_data)
    soc_range_1C_all.append(soc_range_1C_data)
    soc_range_2C_all.append(soc_range_2C_data)
    soc_range_work_all.append(soc_range_work_data)

    # Last resampled point of every sample: EFC (column 2 of the input) and
    # normalised discharge capacity (column 1 of the output).
    final_efc = input_data[:, -1, 2] * 500
    final_capacity = output_data[:, -1, 1]

    # Colour the curve by the EFC of the first sample whose capacity falls
    # below 0.8, or by the last sample if it never does.
    below_threshold = np.flatnonzero(final_capacity < 0.8)
    j = below_threshold[0] if below_threshold.size else len(final_capacity) - 1
    color_i = cmap(np.clip(final_efc[j] / 1700, 0, 1))
    plt.plot(final_efc, final_capacity*nominal_capacity, color=color_i)

# Dashed reference lines at 80 %, 75 % and 70 % of nominal capacity.
efc_axis = np.arange(0,1800,100)
for capacity_fraction in (0.8, 0.75, 0.7):
    plt.plot(efc_axis, capacity_fraction*nominal_capacity*np.ones(len(efc_axis)), '--', color='#F1766D')
plt.xlabel('EFCs')
plt.ylabel('Capacity [Ah]')
plt.show()

# Flatten the per-cell lists into single NumPy arrays by joining along axis 0.
inputs, outputs = (
    np.concatenate(per_cell, axis=0) for per_cell in (inputs, outputs)
)
start_v_all, end_v_all = (
    np.concatenate(per_cell, axis=0) for per_cell in (start_v_all, end_v_all)
)
cap_pari_1C_all, cap_pari_2C_all, cap_pari_work_all = (
    np.concatenate(per_cell, axis=0)
    for per_cell in (cap_pari_1C_all, cap_pari_2C_all, cap_pari_work_all)
)
soc_range_1C_all, soc_range_2C_all, soc_range_work_all = (
    np.concatenate(per_cell, axis=0)
    for per_cell in (soc_range_1C_all, soc_range_2C_all, soc_range_work_all)
)


# Joint KDE of the start and end voltage of all sampled windows.
colors = ["#839DD1","#F1766D"]
cmap2 = LinearSegmentedColormap.from_list("custom_cmap", colors)
mpl.rcParams['figure.dpi'] = 600
# sns.jointplot is a figure-level function and creates its own figure, so no
# plt.figure() call is made beforehand (it only left an empty extra figure).
ax=sns.jointplot(x=pd.Series(start_v_all),
              y=pd.Series(end_v_all),
              cmap = cmap2,fill = True,thresh=0.025,
              kind = 'kde',   # one of "scatter", "reg", "resid", "kde", "hex"
              space = 0.1,    # gap between the joint and the marginal axes
              ratio = 7,      # joint-to-marginal size ratio
              marginal_kws=dict(rug=True, color='#E59F05', fill=True),
              )
plt.tick_params( bottom=False, left=False)
# Resize the grid's figure to the target 6 cm x 5 cm.
ax.fig.set_figwidth(6/2.54)
ax.fig.set_figheight(5/2.54)
ax.set_axis_labels("Start voltage [V]", "End voltage [V]",labelpad=1)
plt.xlim([2.8,3.75])
plt.ylim([3.4,4.3])
plt.show()

# Scatter of the capacity pairs in cap_pari_work_all: column 1 on the x-axis
# (labelled C/5 RPT capacity) against column 0 (labelled measured capacity).
plt.figure(figsize=(6/ 2.54, 5 / 2.54), dpi=600)
plt.ion()
plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
plt.tick_params(top=True, right=True, which='both')
rpt_capacity, measured_capacity = cap_pari_work_all[:,1], cap_pari_work_all[:,0]
plt.scatter(rpt_capacity, measured_capacity, s=12, alpha=0.6, linewidth=0,
            color='#F1766D', label='Cycling')
plt.xlabel('C/5 RPT capacity [Ah]')
plt.ylabel('Measured capacity [Ah]')
plt.show()


# Distribution of soc_range_work_all expressed in percent: density-normalised
# histogram with a KDE outline in the same colour.
fig, ax = plt.subplots(figsize=(6/2.54, 5/2.54), dpi=600)
ax.tick_params(bottom=False, left=False)
soc_range_percent = 100*soc_range_work_all
plt.hist(soc_range_percent, bins=25, fill=True, density=True, linewidth=0.5,
         alpha=0.7, color='#F1766D', edgecolor='black', label='Cycling')
sns.kdeplot(soc_range_percent, fill=False, linewidth=1, color='#F1766D', label=None)
plt.xlabel('SOC range [%]')
plt.ylabel('Density')
plt.show()


# Joint KDE of the capacity pairs in cap_pari_work_all (column 1 on x, column 0 on y).
mpl.rcParams['figure.dpi'] = 600
# sns.jointplot is a figure-level function and creates its own figure, so no
# plt.figure() call is made beforehand (it only left an empty extra figure).
ax=sns.jointplot(x=pd.Series(cap_pari_work_all[:,1]),
              y=pd.Series(cap_pari_work_all[:,0]),
              cmap = cmap2,fill = True,thresh=0.025,
              kind = 'kde',   # one of "scatter", "reg", "resid", "kde", "hex"
              space = 0.1,    # gap between the joint and the marginal axes
              ratio = 7,      # joint-to-marginal size ratio
              marginal_kws=dict(rug=True, color='#E59F05', fill=True),
              )
plt.tick_params( bottom=False, left=False)
# Resize the grid's figure to the target 6 cm x 5 cm.
ax.fig.set_figwidth(6/2.54)
ax.fig.set_figheight(5/2.54)
ax.set_axis_labels("C/5 RPT capacity [Ah]", "Measured capacity [Ah]",labelpad=1)
plt.show()

#%%

# Scan every 'ResVal*' file: record the maximum discharge capacity reached in
# each rate test, and for the first file plot the current / voltage traces
# coloured by cycle type.
folder_loc = os.path.abspath('..\\ResValData')
file_list = [f for f in os.listdir(folder_loc) if f.startswith('ResVal')]
Cap_all =[]

# (cycle type, plot colour); the position in this list is the cycle_index value
# that gets mapped to the cycle type.
cycle_palette = [
    ("start_discharge", 'grey'),
    ("C/80_Cycle", '#F1766D'),
    ("GITT", '#E59F05'),
    ("C/40_Cycle", '#f3a17c'),
    ("0.05A_Cycle_mistake", '#ecc68c'),
    ("C/10_Cycle", '#b3c6bb'),
    ("C/7_cycle", '#78A040'),
    ("C/5_Cycle", '#839DD1'),
    ("1C_Cycle", '#27B2AF'),
    ("2C_Cycle", '#752a80'),
    ("charge_for_storage", 'grey'),
]
# Cycle types whose maximum discharge capacity is collected, in column order.
rate_cycle_types = [
    "C/80_Cycle", "C/40_Cycle", "0.05A_Cycle_mistake", "C/10_Cycle",
    "C/7_cycle", "C/5_Cycle", "1C_Cycle", "2C_Cycle",
]

for i, battery in enumerate(file_list):
    file_name = os.path.join(folder_loc, battery)
    print(battery)
    datapath = structure.MaccorDatapath.from_file(file_name)
    all_cycle_types = [cycle_name for cycle_name, _ in cycle_palette]
    datapath.raw_data["cycle_type"] = datapath.raw_data["cycle_index"].apply(lambda x: all_cycle_types[x])

    data = datapath.raw_data
    Cap_all.append(np.hstack([
        np.max(data[data["cycle_type"] == cycle_name]['discharge_capacity'])
        for cycle_name in rate_cycle_types
    ]))

    if i == 0:
        # Two stacked panels without vertical gap: current on top, voltage below.
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12 / 2.54, 6 / 2.54), dpi=600,
                                       gridspec_kw={'hspace': 0})
        for cycle_name, shade in cycle_palette:
            cycle_rows = data[data["cycle_type"] == cycle_name]
            elapsed = np.array(cycle_rows['test_time'])
            ax1.plot(elapsed, np.array(cycle_rows['current']), color=shade)
            ax2.plot(elapsed, np.array(cycle_rows['voltage']), color=shade)

        ax1.set_ylabel('Current [A]')
        ax1.tick_params(top=True, right=True, which='both', direction='in')
        ax1.spines['bottom'].set_visible(False)
        ax1.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
        ax1.set_xticks([])

        ax2.set_xlabel('Time [s]')
        ax2.set_ylabel('Voltage [V]')
        ax2.tick_params(top=True, right=True, which='both', direction='in')
        ax2.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
        fig.align_labels()
        plt.show()


# Capacity matrix built from Cap_all (one row per file, one column per rate
# test); NaN entries are replaced by 0.
Cap_all_matrix = np.array(Cap_all)
Cap_all_matrix[np.isnan(Cap_all_matrix)] = 0

column_labels = ["C/80", "C/40", "0.05 A", "C/10", "C/7", "C/5", "1C", "2C"]
colors = ["#F1766D", "#f3a17c", "#ecc68c", "#b3c6bb", "#78A040", "#839DD1", "#27B2AF", "#752a80"]

from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Violin plot per rate with the individual values overlaid as jittered dots,
# plus an inset histogram of the C/5 column divided by 4.84.
fig, ax = plt.subplots(figsize=(8.5/2.54, 5.5/2.54), dpi=600)
ax.tick_params(bottom=False, left=False)
parts = plt.violinplot(Cap_all_matrix, showmeans=True)

for body_idx, violin_body in enumerate(parts['bodies']):
    shade = colors[body_idx % len(colors)]
    violin_body.set_facecolor(shade)
    violin_body.set_edgecolor(shade)
    violin_body.set_alpha(0.6)
for line_key in ('cmeans', 'cbars', 'cmins', 'cmaxes'):
    parts[line_key].set_color(colors)

for col, column_values in enumerate(Cap_all_matrix.T):
    # Gaussian jitter around the violin position (1-based) keeps dots apart.
    jittered_x = np.random.normal(col + 1, 0.05, size=len(column_values))
    plt.scatter(jittered_x, column_values, color=colors[col], alpha=0.9, s=5, edgecolor='none')

plt.ylabel("Capacity [Ah]")
tick_positions = range(1, len(column_labels) + 1)
plt.xticks(tick_positions, column_labels, rotation=90)

# Inset anchored towards the lower-left corner of the main axes.
ax_inset = inset_axes(ax, width="50%", height="40%",
                      bbox_to_anchor=(-0.2, -0.25, 0.8, 0.8),
                      bbox_transform=ax.transAxes)

c05_data = Cap_all_matrix[:, 5] / 4.84
ax_inset.tick_params(bottom=False, left=False)
ax_inset.hist(c05_data, bins=20, color=colors[5], alpha=0.8)
ax_inset.yaxis.set_ticklabels([])
ax_inset.tick_params(axis='x', labelsize=10.5)
ax_inset.set_ylabel('Count', labelpad=-2., fontsize=10.5)
plt.show()


# Stand-alone histogram of c05_data with a KDE outline in the C/5 colour.
fig, ax = plt.subplots(figsize=(6/2.54,5/2.54), dpi=600)
ax.tick_params(bottom=False, left=False)
c5_shade = colors[5]
ax.hist(c05_data, bins=20, color=c5_shade, alpha=0.8)
sns.kdeplot(c05_data, fill=False, linewidth=1.5, color=c5_shade, label=None)
plt.xlabel('SOH @ C/5')
plt.ylabel("Count")
plt.show()



# Violin plot of Cap_all_matrix per rate with jittered dots for the individual
# values, y-range limited to [2, 5.1].
fig, ax = plt.subplots(figsize=(8.5/2.54, 5.5/2.54), dpi=600)
ax.tick_params(bottom=False, left=False)
parts = plt.violinplot(Cap_all_matrix, showmeans=True)

for body_idx, violin_body in enumerate(parts['bodies']):
    shade = colors[body_idx % len(colors)]
    violin_body.set_facecolor(shade)
    violin_body.set_edgecolor(shade)
    violin_body.set_alpha(0.6)
for line_key in ('cmeans', 'cbars', 'cmins', 'cmaxes'):
    parts[line_key].set_color(colors)

for col, column_values in enumerate(Cap_all_matrix.T):
    # Gaussian jitter around the violin position (1-based) keeps dots apart.
    jittered_x = np.random.normal(col + 1, 0.05, size=len(column_values))
    plt.scatter(jittered_x, column_values, color=colors[col], alpha=0.9, s=5, edgecolor='none')

plt.ylabel("Capacity [Ah]")
tick_positions = range(1, len(column_labels) + 1)
plt.xticks(tick_positions, column_labels, rotation=90)
plt.ylim([2,5.1])
plt.show()



# One line per rate test: the column of Cap_all_matrix plotted against row index.
fig, ax = plt.subplots(figsize=(8/2.54, 6/2.54), dpi=600)
ax.tick_params(bottom=False, left=False)
n_rates = Cap_all_matrix.shape[1]
positions = range(1, n_rates + 1)
for col, column_values in enumerate(Cap_all_matrix.T):
    plt.plot(column_values, color=colors[col], label=column_labels[col], alpha=0.8)

plt.xlabel("Cells")
plt.ylabel("Capacity [Ah]")
plt.show()




#%%

# For every CSV in ..\dynamicdata: derive a cumulative EFC trace, pick the long
# cycles, read the capacity at the end of the -0.025 C discharge step of each of
# them and draw the resulting capacity-vs-EFC curve on one summary figure. For
# cell 095 an additional current / voltage detail figure is produced.
folder_loc = os.path.abspath('..\\dynamicdata')
file_list = os.listdir(folder_loc)
Cap_all =[]
EFC_all =[]

summary_fig = plt.figure(num=None,figsize=(8/2.54,4/2.54),dpi=600)
colors = ["#839DD1","#78A040"]
cmap = LinearSegmentedColormap.from_list("custom_cmap", colors)

plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
# Keep an explicit handle on the summary axes. Plotting through the pyplot
# "current axes" sent the curves of every cell processed after cell 095 (and
# the 'EFCs'/'SOH' labels) into the voltage panel of the detail figure.
summary_ax = summary_fig.add_subplot(111)
summary_ax.tick_params(top='on', right='on', which='both')
summary_ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
summary_ax.set_xlabel('EFCs')
summary_ax.set_ylabel('SOH')

# Step-change index pairs (into idx_end_dis) highlighted in both detail panels.
detail_highlights = [
    (9, 11, '#27B2AF'),
    (11, 38, '#78A040'),
    (38, 43, '#27B2AF'),
    (43, 78, '#E59F05'),
    (78, 84, '#839DD1'),
]

for battery in file_list:
    file_name = os.path.join(folder_loc, battery)
    print(battery)
    dynamic_data = pd.read_csv(file_name, low_memory=False)
    Capacity_all = dynamic_data['Normalized Capacity (nominal capacity unit)'].values

    # Cumulative throughput: each time |capacity| decreases, the previous value
    # is added to a running offset. Halved afterwards to give EFCs.
    EFC_all = np.zeros_like(Capacity_all)
    cumulative_offset = 0
    for row in range(len(Capacity_all)):
        if abs(Capacity_all[row]) > 1000:
            # Outlier: reuse the previous sample.
            # NOTE(review): for row 0 this wraps around to the last sample.
            Capacity_all[row] = Capacity_all[row - 1]

        if row > 0 and abs(Capacity_all[row]) < abs(Capacity_all[row - 1]):
            cumulative_offset += abs(Capacity_all[row - 1])
        EFC_all[row] = abs(Capacity_all[row]) + cumulative_offset
    EFC_all = EFC_all / 2

    grouped = dynamic_data.groupby("Cyc#")
    Voltage_cycles = [group["Volts"].values for _, group in grouped]
    Capacity_cycles = [group["Normalized Capacity (nominal capacity unit)"].values for _, group in grouped]
    Step_cycles = [group["Step"].values for _, group in grouped]
    Current_cycles = [group["Normalized Current (C-rate)"].values for _, group in grouped]
    Time_cycles = [group["Test (Sec)"].values for _, group in grouped]
    EFC_cycles = [EFC_all[group.index] for _, group in grouped]

    # Long cycles (> 20000 rows spanning > 200000 s) that are followed by a
    # cycle with more than 2000 rows.
    selected_cycles = [
        cyc for cyc in range(len(Time_cycles) - 1)
        if len(Time_cycles[cyc]) > 20000
        and (Time_cycles[cyc][-1] - Time_cycles[cyc][0]) > 200000
        and len(Time_cycles[cyc + 1]) > 2000
    ]

    cap_curve = []
    EFC_curve = []
    for rpt_data_cycle in selected_cycles[:-1]:
        cycle_current = np.array(Current_cycles[rpt_data_cycle])
        cycle_voltage = np.array(Voltage_cycles[rpt_data_cycle])
        cycle_capacity = np.array(Capacity_cycles[rpt_data_cycle])
        cycle_step = np.array(Step_cycles[rpt_data_cycle])
        cycle_efc = np.array(EFC_cycles[rpt_data_cycle])

        # Longest run of consecutive rows with a current of exactly -0.025.
        indices = np.where(cycle_current == -0.025)[0]
        C040_segment = []
        current_segment = [indices[0]] if len(indices) > 0 else []
        for n in range(1, len(indices)):
            if indices[n] == indices[n - 1] + 1:
                current_segment.append(indices[n])
            else:
                if len(current_segment) > len(C040_segment):
                    C040_segment = current_segment
                current_segment = [indices[n]]
        if len(current_segment) > len(C040_segment):
            C040_segment = current_segment
        if len(C040_segment) == 0:
            # No -0.025 rows in this cycle: nothing to evaluate.
            continue

        # Widen the run to every row that shares its step number.
        C040_segment = np.where(cycle_step == cycle_step[C040_segment[0]])[0]

        # i: first row with negative current and positive capacity.
        for i in range(len(C040_segment)):
            if cycle_current[C040_segment[i]] < 0 and cycle_capacity[C040_segment[i]] > 0:
                break

        # j: first later row with non-negative current at <= 2.801 V.
        # Initialised so an empty range cannot leave a stale value behind.
        j = i
        for j in range(i, len(C040_segment) - 1):
            if cycle_current[C040_segment[j]] >= 0 and cycle_voltage[C040_segment[j]] <= 2.801:
                break

        # k: first row at <= 2.801 V after which the voltage rises again.
        k = i
        for k in range(i, j):
            if cycle_voltage[C040_segment[k + 1]] > cycle_voltage[C040_segment[k]] and cycle_voltage[C040_segment[k]] <= 2.801:
                break

        C040_segment = C040_segment[i:k]
        if len(C040_segment) == 0:
            # Degenerate discharge window: skip instead of indexing [-1].
            continue

        cap_curve.append(cycle_capacity[C040_segment][-1])
        EFC_curve.append(cycle_efc[C040_segment][-1])

    if EFC_curve:
        # Colour by the largest EFC of the curve relative to 2500.
        color_i = cmap(np.clip(np.max(np.array(EFC_curve) / 2500), 0, 1))
        summary_ax.plot(np.array(EFC_curve), np.array(cap_curve), color=color_i)

    Cap_all.append(np.array(cap_curve))

    if battery =='Publishing_data_raw_data_cell_095.csv':
        # Four consecutive cycles around the second selected cycle.
        anchor_cycle = selected_cycles[1]
        shown_cycles = [anchor_cycle - 2, anchor_cycle - 1, anchor_cycle, anchor_cycle + 1]
        time_show = np.concatenate([Time_cycles[cyc] for cyc in shown_cycles])
        voltage_show = np.concatenate([Voltage_cycles[cyc] for cyc in shown_cycles])
        current_show = np.concatenate([Current_cycles[cyc] for cyc in shown_cycles])
        step_show = np.concatenate([Step_cycles[cyc] for cyc in shown_cycles])

        # Rows after which the step number changes.
        idx_end_dis = np.array([
            n for n in range(1, len(time_show) - 10)
            if step_show[n + 1] != step_show[n]
        ])

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12 / 2.54, 6 / 2.54), dpi=600,
                                       gridspec_kw={'hspace': 0})
        # NOTE(review): the last highlighted section ends at step change 89 in
        # the current panel but 88 in the voltage panel; kept as found -
        # confirm which one is intended.
        for panel, signal, last_stop in ((ax1, current_show, 89), (ax2, voltage_show, 88)):
            panel.plot(time_show, signal, color='#F1766D')
            for lo, hi, shade in detail_highlights + [(84, last_stop, '#752a80')]:
                first_row, last_row = idx_end_dis[lo], idx_end_dis[hi]
                panel.plot(time_show[first_row:last_row], signal[first_row:last_row],
                           label='Voltage', color=shade)

        ax1.set_ylabel('Current [C]')
        ax1.tick_params(top=True, right=True, which='both', direction='in')
        ax1.spines['bottom'].set_visible(False)
        ax1.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
        ax1.set_xticks([])

        ax2.set_xlabel('Time [s]')
        ax2.set_ylabel('Voltage [V]')
        ax2.tick_params(top=True, right=True, which='both', direction='in')
        ax2.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
        fig.align_labels()
        # No plt.show() here: showing mid-loop can flush or close the still
        # unfinished summary figure; both figures are shown once below.

plt.show()

    
    #%%
    
    




