import csv
import matplotlib.pyplot as plt
import numpy as np
import tikzplotlib

# Wave gauge whose time series is analysed.
wgnum = 3

# Wave gauge number -> column index of that gauge's reading in each SPH CSV row
# (e.g. gauge 3 is read from column 5).
wgexp_dict = {1: 3, 2: 4, 3: 5, 4: 6, 5: 7, 6: 8, 7: 11, 8: 2, 9: 9, 10: 10}
wgexp = wgexp_dict[wgnum]

# One (x, y) series per dp run, dp1 ... dp5; filled by the SPH loaders below.
x_sph_dp1, y_sph_dp1 = [], []
x_sph_dp2, y_sph_dp2 = [], []
x_sph_dp3, y_sph_dp3 = [], []
x_sph_dp4, y_sph_dp4 = [], []
x_sph_dp5, y_sph_dp5 = [], []

def _read_sph_gauge(path, column, t_min=2.0, t_max=12.0, header_rows=4):
    """Read one SPH wave-gauge CSV export and return ``(xs, ys)`` lists.

    The file is ';'-delimited and its first ``header_rows`` lines are skipped.
    Column 1 holds the x value (time axis); only rows with
    ``t_min <= x <= t_max`` (inclusive) are kept. Column ``column`` holds the
    gauge reading in metres, which is stored as ``value * 100 - 75``: the
    ``* 100`` converts metres to centimetres, and the ``- 75`` is presumably
    a still-water-level offset in cm -- TODO confirm.
    Blank rows (e.g. a trailing newline at the end of the file) are skipped.
    """
    xs = []
    ys = []
    with open(path, newline='') as csvfile:
        csvreader = csv.reader(csvfile, delimiter=';')
        for _ in range(header_rows):
            next(csvreader)
        for row in csvreader:
            if not row:
                continue
            x_val = float(row[1])
            if t_min <= x_val <= t_max:
                xs.append(x_val)
                ys.append(float(row[column]) * 100 - 75)
    return xs, ys


# SPH data, one file per dp run (file name = dp value):
#   dp1 -> wg0.2.csv, dp2 -> wg0.1.csv, dp3 -> wg0.05.csv,
#   dp4 -> wg0.025.csv, dp5 -> wg0.0125.csv
# The dp1 and dp3 runs are currently disabled; uncomment to load them.
# x_sph_dp1, y_sph_dp1 = _read_sph_gauge('wg0.2.csv', wgexp)
x_sph_dp2, y_sph_dp2 = _read_sph_gauge('wg0.1.csv', wgexp)
# x_sph_dp3, y_sph_dp3 = _read_sph_gauge('wg0.05.csv', wgexp)
x_sph_dp4, y_sph_dp4 = _read_sph_gauge('wg0.025.csv', wgexp)
x_sph_dp5, y_sph_dp5 = _read_sph_gauge('wg0.0125.csv', wgexp)


# Experimental data (in centimeters): ';'-delimited rows of (x, y), no header.
x_exp = []
y_exp = []
exp_path = 'wg{}_exp.csv'.format(wgnum)
with open(exp_path, newline='') as csvfile:
    for row in csv.reader(csvfile, delimiter=';'):
        if not row:
            continue  # tolerate blank lines, e.g. a trailing newline
        x_exp.append(float(row[0]))
        y_exp.append(float(row[1]))

# Without this check an empty file fails below with an opaque
# "not enough values to unpack" error.
if not x_exp:
    raise ValueError('{} contains no data rows'.format(exp_path))

# Sort the pairs by x (ties broken by y). zip(*...) yields tuples, which the
# rest of the script only reads.
x_exp, y_exp = zip(*sorted(zip(x_exp, y_exp)))


### dp2 for 0.1, dp4 for 0.025, dp5 for 0.0125
#
# Both metrics are relative errors (%) of an SPH run against the experiment,
# computed as (reference - simulated) / reference * 100 (for a positive
# reference value):
#   - peak error:         compares the maximum elevations;
#                         > 0 means the SPH peak is lower than measured.
#   - arrival time error: compares the x value at the (first) maximum;
#                         > 0 means the SPH peak occurs earlier than measured.


def _relative_error(reference, simulated):
    """Return (reference - simulated) / reference as a percentage.

    Raises ZeroDivisionError when ``reference`` is 0.
    """
    return ((reference - simulated) / reference) * 100


def _peak_time(x, y):
    """Return ``(index, x[index])`` for the first occurrence of ``max(y)``."""
    index = y.index(max(y))
    return index, x[index]


# Peak-elevation error for the dp = 0.1 run.
error = _relative_error(max(y_exp), max(y_sph_dp2))
print("Error in peak free surface elevation: {:.3f} %".format(error))

# Arrival time = x value at which each series first reaches its maximum.
max_index_exp, arrival_time_exp = _peak_time(x_exp, y_exp)
max_index_sph2, arrival_time_sph2 = _peak_time(x_sph_dp2, y_sph_dp2)
max_index_sph4, arrival_time_sph4 = _peak_time(x_sph_dp4, y_sph_dp4)
max_index_sph5, arrival_time_sph5 = _peak_time(x_sph_dp5, y_sph_dp5)

# Arrival-time error for the dp = 0.1 run.
arrival_time_error2 = _relative_error(arrival_time_exp, arrival_time_sph2)
print("Arrival time error: {:.2f} %".format(arrival_time_error2))

# Both metrics for every plotted dp run of the selected gauge, in the order
# dp = 0.1, 0.025, 0.0125 (one entry per dp run, not per wave gauge).
errors = [
    _relative_error(max(y_exp), max(y_sim))
    for y_sim in (y_sph_dp2, y_sph_dp4, y_sph_dp5)
]
arrival_time_errors = [
    _relative_error(arrival_time_exp, t_sim)
    for t_sim in (arrival_time_sph2, arrival_time_sph4, arrival_time_sph5)
]

# x-tick labels: the dp value of each entry in `errors` / `arrival_time_errors`
# (the variable name is historical; the entries are dp values, not gauges).
wave_gauges = ['0.1', '0.025', '0.0125']

# Grouped bar chart: one group per dp value, with the peak-elevation error and
# the arrival-time error side by side.
bar_width = 0.35
bar_positions1 = np.arange(len(wave_gauges))
bar_positions2 = bar_positions1 + bar_width

# Colours are explicit ('C0' is matplotlib's default first colour) so the
# stand-alone legend figure can match them.
plt.bar(bar_positions1, errors, bar_width, label='Elevation error (%)', color='C0')
plt.bar(bar_positions2, arrival_time_errors, bar_width, label='Arrival time error (%)', color='orange')

plt.xlabel('dp')
plt.ylabel('Error (%)')
# Centre each tick label under its pair of bars.
plt.xticks(bar_positions1 + bar_width / 2, wave_gauges)
# No in-plot legend: the legend is drawn on its own figure (wglegend.pdf).

# Symmetric y-axis of +/-15 %, widened (rounded up to a whole percent) only
# when a bar would otherwise be clipped.
largest_error = np.max(np.abs(np.concatenate([errors, arrival_time_errors])))
y_limit = max(15, np.ceil(largest_error))
plt.ylim(-y_limit, y_limit)

output_stem = 'wg{}c_error'.format(wgnum)

# Save the plot as a TikZ file
tikzplotlib.save(output_stem + '.tex')

# Save the figure as a PDF
plt.savefig(output_stem + '.pdf', format='pdf', bbox_inches='tight', pad_inches=0)

plt.show()



# Stand-alone horizontal legend for the error bar chart, saved to its own PDF.
# A new figure is created explicitly. With a non-interactive backend,
# plt.show() does not close the bar-chart figure, so drawing without this
# would reuse that figure and put the bars into wglegend.pdf.
plt.figure()
plt.axis('off')  # Turn off axis and ticks

# Proxy entries using the same colours as the bars. 'C0' is the default
# colour of the elevation-error bars; plain 'blue' (#0000ff) would not match.
plt.plot([], [], color='C0', label='Elevation error (%)')
plt.plot([], [], color='orange', label='Arrival time error (%)')

# Create a horizontal legend
plt.legend(loc='center', ncol=2)

# Save the legend as a separate PDF
plt.savefig('wglegend.pdf', format='pdf', bbox_inches='tight', pad_inches=0)

plt.show()