In [1]:
import pandas as pd
import numpy as np
import statsmodels.api as sm
import matplotlib.pyplot as plt
import netCDF4 as nc
import pandas as pd
import os
import csv
from glob import glob
import xarray as xr
import matplotlib.ticker as ticker 

import numpy as np
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

import numpy as np
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.colors import LinearSegmentedColormap, Normalize
In [2]:
# Step 1: Data Preparation
# Load your dataset into a pandas DataFrame
os.chdir('/Users/chenchenren/postdoc/paper/2N and water-US/Figure 1/')
In [3]:
excel_file = "data_ok240115.xlsx"
sheet_name = "maize"
data = pd.read_excel(excel_file, sheet_name=sheet_name)

# Create interaction terms
data['tmp_tmp_interaction'] = data['tmp'] * data['tmp']
data['irrigation12'] = data['irrigation1'] * data['irrigation1']
data['lnfer_irrigation1_tmp_interaction'] = data['lnfer'] * data['irrigation1'] * data['tmp']
data['lnfer_irrigation12_tmp_tmp_interaction'] = data['lnfer']* data['irrigation1']  * data['irrigation1'] * data['tmp'] * data['tmp']

data['lnfer_tmp_interaction'] = data['lnfer'] * data['tmp']
data['lnfer_tmp_tmp_interaction'] = data['lnfer'] * data['tmp'] * data['tmp']

# Define predictor variables and response variable
predictor_variables1 = ['lnfer', 'tmp', 'irrigation1', 'irrigation12', 'tmp_tmp_interaction',
                       'lnfer_irrigation1_tmp_interaction', 'lnfer_irrigation12_tmp_tmp_interaction']

predictor_variables2 = ['lnfer', 'tmp', 'tmp_tmp_interaction','irrigation1',
                       'lnfer_tmp_interaction', 'lnfer_tmp_tmp_interaction']

response_variable = 'lnyield'

# Get unique group values
group_values = data['group'].unique()

# Create a figure with a grid layout
fig = plt.figure(figsize=(10, 3.5))
gs = GridSpec(1, 3, wspace=0.35, width_ratios=[4, 4, 1.2])  # Create a 1x3 grid for main plots and color map

# Define a custom gradient color map
cmap_main = plt.get_cmap('RdYlGn')
norm_main = Normalize(5, 10)  # Adjust the normalization range for the main plots color map
colors = []
values = []

for group_idx, group in enumerate(group_values):
    # Filter data for the current group
    group_data = data[data['group'] == group].copy()  # Make a copy to avoid warnings
    
    # Select predictor variables based on the group
    predictor_variables = predictor_variables1 if group == 1 else predictor_variables2
    
    # Perform regression analysis for the current group
    X = sm.add_constant(group_data[predictor_variables])
    y = group_data[response_variable]
    reg = sm.OLS(y, X).fit()
    
    x_temp = np.linspace(10, 30, 100)
    lnfer_range = np.linspace(4.4, 6.9, 100)
    irrigation1_range = np.linspace(0, 4, 100)
    
    # Update the DataFrame using .loc to avoid warnings
    group_data.loc[:, 'tmp_tmp_interaction'] = group_data['tmp'] * group_data['tmp']
    group_data.loc[:, 'lnfer_irrigation1_tmp_interaction'] = group_data['lnfer'] * group_data['irrigation1'] * group_data['tmp']
    group_data.loc[:, 'lnfer_irrigation12_tmp_tmp_interaction'] = group_data['lnfer'] * group_data['irrigation1'] * group_data['irrigation1']* group_data['tmp'] * group_data['tmp']
    
    if group != 1:
        group_data.loc[:, 'lnfer_tmp_interaction'] = group_data['lnfer'] * group_data['tmp']
        group_data.loc[:, 'lnfer_tmp_tmp_interaction'] = group_data['lnfer'] * group_data['tmp'] * group_data['tmp']
    
    for i in range(len(irrigation1_range)):
        irrigation1_value = irrigation1_range[i]
        
        # Only plot lines where irrigation is equal to 0 in group 2
        if group == 2 and irrigation1_value != 0:
            continue
        
        for j in range(len(lnfer_range)):
            lnfer_value = lnfer_range[j]
            
            if group == 1:
                y_temp = (reg.params['tmp_tmp_interaction'] * x_temp**2 +
                          reg.params['tmp'] * x_temp +
                          reg.params['lnfer'] * lnfer_value +
                          reg.params['irrigation1'] * irrigation1_value +
                          reg.params['irrigation12'] * irrigation1_value**2 +
                          reg.params['lnfer_irrigation1_tmp_interaction'] * irrigation1_value * x_temp * lnfer_value +
                          reg.params['lnfer_irrigation12_tmp_tmp_interaction'] * irrigation1_value**2 * lnfer_value * x_temp**2 +
                          reg.params['const'])
            else:
                y_temp = (reg.params['tmp_tmp_interaction'] * x_temp**2 +
                          reg.params['tmp'] * x_temp +
                          reg.params['lnfer'] * lnfer_value +
                          reg.params['irrigation1'] * irrigation1_value +
                          reg.params['lnfer_tmp_interaction'] * x_temp * lnfer_value +
                          reg.params['lnfer_tmp_tmp_interaction'] * lnfer_value * x_temp**2 +
                          reg.params['const'])
            # Create a subplot for each group
            ax = plt.subplot(gs[group_idx])
            
            # Calculate the color based on the normalization
            color = cmap_main(norm_main(lnfer_value + irrigation1_value))
            
            ax.plot(x_temp, y_temp, color=color, linewidth=2, alpha=0.8)

            # Collect the color and its corresponding values
            colors.append(color)
            values.append((irrigation1_value, lnfer_value))  # Store the corresponding values

            
            # Set common labels and ticks for x and y axes
            ax.set_xlabel('Air temperature (°C)',  fontfamily='Arial',fontsize=16)
            ax.set_ylabel('Ln Yield (kg/ha/yr)', fontfamily='Arial', fontsize=16, labelpad=-3)
            x_ticks = np.arange(10, 31, 5)
            y_ticks = np.arange(-1.5, 5, 2)
            ax.set_xticks(x_ticks)
            ax.set_yticks(y_ticks)
            ax.set_yticklabels(ax.get_yticks(), fontfamily='Arial',fontsize=14)
            ax.set_xticklabels(ax.get_xticks(), fontfamily='Arial',fontsize=14)

# Create a subplot for the color map
ax_color_map = plt.subplot(gs[2])

# Create a scatter plot for the color map subplot
scatter = ax_color_map.scatter(*zip(*values), c=colors, cmap=cmap_main, marker='.', s=100)

# Set common labels and ticks for x and y axes in the color map subplot
ax_color_map.set_xlabel('Irrigation\n($10^{3}$m$^{3}$/ha/yr)', fontfamily='Arial',fontsize=12)
ax_color_map.set_ylabel('Ln N input (kg/ha/yr)', fontfamily='Arial',fontsize=14, x=0.98)
ax_color_map.set_xticks(np.arange(0, 5, 2))
ax_color_map.set_xticklabels([f'{round(tick, 1)}' for tick in ax_color_map.get_xticks()], fontfamily='Arial',fontsize=14)
ax_color_map.set_yticks(np.arange(4.4, 7.0, 0.5))
ax_color_map.set_yticklabels(ax_color_map.get_yticks(),fontfamily='Arial', fontsize=14)


# Save the entire figure as a JPG with DPI 300
plt.savefig('maize_tmp.jpg', format='jpg', dpi=300, bbox_inches='tight', pad_inches=0.1)

# Show the plot
plt.show()
/var/folders/vd/0_phd7hx2n51y4412862zww00000gp/T/ipykernel_7375/1152588395.py:115: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
  scatter = ax_color_map.scatter(*zip(*values), c=colors, cmap=cmap_main, marker='.', s=100)
In [4]:
excel_file = "data_ok240115.xlsx"
sheet_name = "soybean"
data = pd.read_excel(excel_file, sheet_name=sheet_name)

# Create interaction terms
data['tmp_tmp_interaction'] = data['tmp'] * data['tmp']
data['irrigation12'] = data['irrigation1'] * data['irrigation1']
data['lnfer_irrigation1_tmp_interaction'] = data['lnfer'] * data['irrigation1'] * data['tmp']
data['lnfer_irrigation12_tmp_tmp_interaction'] = data['lnfer']* data['irrigation1']  * data['irrigation1'] * data['tmp'] * data['tmp']

data['lnfer_tmp_interaction'] = data['lnfer'] * data['tmp']
data['lnfer_tmp_tmp_interaction'] = data['lnfer'] * data['tmp'] * data['tmp']

# Define predictor variables and response variable
predictor_variables1 = ['lnfer', 'tmp', 'irrigation1', 'irrigation12', 'tmp_tmp_interaction',
                       'lnfer_irrigation1_tmp_interaction', 'lnfer_irrigation12_tmp_tmp_interaction']

predictor_variables2 = ['lnfer', 'tmp', 'tmp_tmp_interaction','irrigation1',
                       'lnfer_tmp_interaction', 'lnfer_tmp_tmp_interaction']

response_variable = 'lnyield'

# Get unique group values
group_values = data['group'].unique()

# Create a figure with a grid layout
fig = plt.figure(figsize=(10, 3.5))
gs = GridSpec(1, 3, wspace=0.35, width_ratios=[4, 4, 1.2])  # Create a 1x3 grid for main plots and color map

# Define a custom gradient color map
cmap_main = plt.get_cmap('PuOr')
norm_main = Normalize(1, 5.3)  # Adjust the normalization range for the main plots color map
colors = []
values = []

for group_idx, group in enumerate(group_values):
    # Filter data for the current group
    group_data = data[data['group'] == group].copy()  # Make a copy to avoid warnings
    
    # Select predictor variables based on the group
    predictor_variables = predictor_variables1 if group == 1 else predictor_variables2
    
    # Perform regression analysis for the current group
    X = sm.add_constant(group_data[predictor_variables])
    y = group_data[response_variable]
    reg = sm.OLS(y, X).fit()
    
    x_temp = np.linspace(14, 30, 100)
    lnfer_range = np.linspace(0.3, 4.3, 100)
    irrigation1_range = np.linspace(0, 2, 100)
    
    # Update the DataFrame using .loc to avoid warnings
    group_data.loc[:, 'tmp_tmp_interaction'] = group_data['tmp'] * group_data['tmp']
    group_data.loc[:, 'lnfer_irrigation1_tmp_interaction'] = group_data['lnfer'] * group_data['irrigation1'] * group_data['tmp']
    group_data.loc[:, 'lnfer_irrigation12_tmp_tmp_interaction'] = group_data['lnfer'] * group_data['irrigation1'] * group_data['irrigation1']* group_data['tmp'] * group_data['tmp']
    
    if group != 1:
        group_data.loc[:, 'lnfer_tmp_interaction'] = group_data['lnfer'] * group_data['tmp']
        group_data.loc[:, 'lnfer_tmp_tmp_interaction'] = group_data['lnfer'] * group_data['tmp'] * group_data['tmp']
    
    for i in range(len(irrigation1_range)):
        irrigation1_value = irrigation1_range[i]
        
        # Only plot lines where irrigation is equal to 0 in group 2
        if group == 2 and irrigation1_value != 0:
            continue
        
        for j in range(len(lnfer_range)):
            lnfer_value = lnfer_range[j]
            
            if group == 1:
                y_temp = (reg.params['tmp_tmp_interaction'] * x_temp**2 +
                          reg.params['tmp'] * x_temp +
                          reg.params['lnfer'] * lnfer_value +
                          reg.params['irrigation1'] * irrigation1_value +
                          reg.params['irrigation12'] * irrigation1_value**2 +
                          reg.params['lnfer_irrigation1_tmp_interaction'] * irrigation1_value * x_temp * lnfer_value +
                          reg.params['lnfer_irrigation12_tmp_tmp_interaction'] * irrigation1_value**2 * lnfer_value * x_temp**2 +
                          reg.params['const'])
            else:
                y_temp = (reg.params['tmp_tmp_interaction'] * x_temp**2 +
                          reg.params['tmp'] * x_temp +
                          reg.params['lnfer'] * lnfer_value +
                          reg.params['irrigation1'] * irrigation1_value +
                          reg.params['lnfer_tmp_interaction'] * x_temp * lnfer_value +
                          reg.params['lnfer_tmp_tmp_interaction'] * lnfer_value * x_temp**2 +
                          reg.params['const'])
            
            # Create a subplot for each group
            ax = plt.subplot(gs[group_idx])
            
            # Calculate the color based on the normalization
            color = cmap_main(norm_main(lnfer_value + irrigation1_value))
            
            ax.plot(x_temp, y_temp, color=color, linewidth=2, alpha=0.8)

            # Collect the color and its corresponding values
            colors.append(color)
            values.append((irrigation1_value, lnfer_value))  # Store the corresponding values
         
            # Set common labels and ticks for x and y axes
            ax.set_xlabel('Air temperature (°C)', fontsize=16, fontfamily='Arial')
            ax.set_ylabel('Ln Yield (kg/ha/yr)', fontsize=16, fontfamily='Arial')
            x_ticks = np.arange(14, 31, 4)
            y_ticks = np.arange(1, 8, 2)
            ax.set_xticks(x_ticks)
            ax.set_yticks(y_ticks)
            ax.set_yticklabels([f'{round(tick, 1)}' for tick in ax.get_yticks()], fontsize=14, fontfamily='Arial')
            ax.set_xticklabels(ax.get_xticks(), fontsize=14, fontfamily='Arial')


# Create a subplot for the color map
ax_color_map = plt.subplot(gs[2])

# Create a scatter plot for the color map subplot
scatter = ax_color_map.scatter(*zip(*values), c=colors, cmap=cmap_main, marker='.', s=100)
 

# Set common labels and ticks for x and y axes in the color map subplot
ax_color_map.set_xlabel('Irrigation\n($10^{3}$m$^{3}$/ha/yr)', fontfamily='Arial',fontsize=12)
ax_color_map.set_ylabel('Ln N input (kg/ha/yr)', fontsize=14, fontfamily='Arial')
ax_color_map.set_xticks(np.arange(0, 3, 1))
ax_color_map.set_xticklabels([f'{round(tick, 1)}' for tick in ax_color_map.get_xticks()], fontsize=14, fontfamily='Arial')
ax_color_map.set_yticks(np.arange(0.3, 4.4, 1.0))
ax_color_map.set_yticklabels(ax_color_map.get_yticks(), fontsize=14, fontfamily='Arial')

# Save the entire figure as a JPG with DPI 300
plt.savefig('soybean_tmp.jpg', format='jpg', dpi=300, bbox_inches='tight', pad_inches=0.1)

# Show the plot
plt.show()
/var/folders/vd/0_phd7hx2n51y4412862zww00000gp/T/ipykernel_7375/426696894.py:116: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
  scatter = ax_color_map.scatter(*zip(*values), c=colors, cmap=cmap_main, marker='.', s=100)
In [5]:
excel_file = "data_ok240115.xlsx"
sheet_name = "maize_pre"
data = pd.read_excel(excel_file, sheet_name=sheet_name)

# Create interaction terms
data['pre_pre_interaction'] = data['pre'] * data['pre']
data['irrigation12'] = data['irrigation1'] * data['irrigation1']
data['lnfer_irrigation1_pre_interaction'] = data['lnfer'] * data['irrigation1'] * data['pre']
data['lnfer_irrigation12_pre_pre_interaction'] = data['lnfer']* data['irrigation1']  * data['irrigation1'] * data['pre'] * data['pre']

data['lnfer_pre_interaction'] = data['lnfer'] * data['pre']
data['lnfer_pre_pre_interaction'] = data['lnfer'] * data['pre'] * data['pre']

# Define predictor variables and response variable
predictor_variables1 = ['lnfer', 'pre', 'irrigation1', 'irrigation12', 'pre_pre_interaction',
                       'lnfer_irrigation1_pre_interaction', 'lnfer_irrigation12_pre_pre_interaction']

predictor_variables2 = ['lnfer', 'pre', 'pre_pre_interaction','irrigation1',
                       'lnfer_pre_interaction', 'lnfer_pre_pre_interaction']

response_variable = 'lnyield'

# Get unique group values
group_values = data['group'].unique()

# Create a figure with a grid layout
fig = plt.figure(figsize=(10, 3.5))
gs = GridSpec(1, 3, wspace=0.35, width_ratios=[4, 4, 1.2])  # Create a 1x3 grid for main plots and color map

# Define a custom gradient color map
cmap_main = plt.get_cmap('RdYlGn')
norm_main = Normalize(5, 10)  # Adjust the normalization range for the main plots color map
colors = []
values = []

for group_idx, group in enumerate(group_values):
    # Filter data for the current group
    group_data = data[data['group'] == group].copy()  # Make a copy to avoid warnings
    
    # Select predictor variables based on the group
    predictor_variables = predictor_variables1 if group == 1 else predictor_variables2
    
    # Perform regression analysis for the current group
    X = sm.add_constant(group_data[predictor_variables])
    y = group_data[response_variable]
    reg = sm.OLS(y, X).fit()
    
    x_pre = np.linspace(0, 12, 100)
    lnfer_range = np.linspace(4.4, 6.9, 100)
    irrigation1_range = np.linspace(0, 4, 100)
    
    # Update the DataFrame using .loc to avoid warnings
    group_data.loc[:, 'pre_pre_interaction'] = group_data['pre'] * group_data['pre']
    group_data.loc[:, 'lnfer_irrigation1_pre_interaction'] = group_data['lnfer'] * group_data['irrigation1'] * group_data['pre']
    group_data.loc[:, 'lnfer_irrigation12_pre_pre_interaction'] = group_data['lnfer'] * group_data['irrigation1'] * group_data['irrigation1']* group_data['pre'] * group_data['pre']
    
    if group != 1:
        group_data.loc[:, 'lnfer_pre_interaction'] = group_data['lnfer'] * group_data['pre']
        group_data.loc[:, 'lnfer_pre_pre_interaction'] = group_data['lnfer'] * group_data['pre'] * group_data['pre']
    
    for i in range(len(irrigation1_range)):
        irrigation1_value = irrigation1_range[i]
        
        # Only plot lines where irrigation is equal to 0 in group 2
        if group == 2 and irrigation1_value != 0:
            continue
        
        for j in range(len(lnfer_range)):
            lnfer_value = lnfer_range[j]
            
            if group == 1:
                y_pre = (reg.params['pre_pre_interaction'] * x_pre**2 +
                          reg.params['pre'] * x_pre +
                          reg.params['lnfer'] * lnfer_value +
                          reg.params['irrigation1'] * irrigation1_value +
                          reg.params['irrigation12'] * irrigation1_value**2 +
                          reg.params['lnfer_irrigation1_pre_interaction'] * irrigation1_value * x_pre * lnfer_value +
                          reg.params['lnfer_irrigation12_pre_pre_interaction'] * irrigation1_value**2 * lnfer_value * x_pre**2 +
                          reg.params['const'])
            else:
                y_pre = (reg.params['pre_pre_interaction'] * x_pre**2 +
                          reg.params['pre'] * x_pre +
                          reg.params['lnfer'] * lnfer_value +
                          reg.params['irrigation1'] * irrigation1_value +
                          reg.params['lnfer_pre_interaction'] * x_pre * lnfer_value +
                          reg.params['lnfer_pre_pre_interaction'] * lnfer_value * x_pre**2 +
                          reg.params['const'])
            
            # Create a subplot for each group
            ax = plt.subplot(gs[group_idx])
            
            # Calculate the color based on the normalization
            color = cmap_main(norm_main(lnfer_value + irrigation1_value))
            
            ax.plot(x_pre, y_pre, color=color, linewidth=2, alpha=0.8)

            # Collect the color and its corresponding values
            colors.append(color)
            values.append((irrigation1_value, lnfer_value))  # Store the corresponding values
           
            # Set common labels and ticks for x and y axes
            ax.set_xlabel('Precipitation (100mm)', fontsize=16, fontfamily='Arial')
            ax.set_ylabel('Ln Yield (kg/ha/yr)', fontsize=16, fontfamily='Arial')
            x_ticks = np.arange(0, 13, 4)
            y_ticks = np.arange(2, 5.7, 0.9)
            formatted_y_ticks = [f'{tick:.1f}' for tick in y_ticks]
            plt.yticks(y_ticks, formatted_y_ticks)
            ax.set_xticks(x_ticks)
            ax.set_yticks(y_ticks)
            ax.set_yticklabels(formatted_y_ticks, fontsize=14, fontfamily='Arial')
            ax.set_xticklabels(ax.get_xticks(), fontsize=14, fontfamily='Arial')


# Create a subplot for the color map
ax_color_map = plt.subplot(gs[2])

# Create a scatter plot for the color map subplot
scatter = ax_color_map.scatter(*zip(*values), c=colors, cmap=cmap_main, marker='.', s=100)


# Set common labels and ticks for x and y axes in the color map subplot
ax_color_map.set_xlabel('Irrigation\n($10^{3}$m$^{3}$/ha/yr)', fontfamily='Arial',fontsize=12)
ax_color_map.set_ylabel('Ln N input (kg/ha/yr)', fontsize=14, fontfamily='Arial')
ax_color_map.set_xticks(np.arange(0, 5, 2))
ax_color_map.set_xticklabels([f'{round(tick, 1)}' for tick in ax_color_map.get_xticks()], fontsize=14, fontfamily='Arial')
ax_color_map.set_yticks(np.arange(4.4, 7.0, 0.5))
ax_color_map.set_yticklabels(ax_color_map.get_yticks(), fontsize=14, fontfamily='Arial')


plt.savefig('maize_pre.jpg', format='jpg', dpi=300, bbox_inches='tight', pad_inches=0.1)
# Show the plot
plt.show()
/var/folders/vd/0_phd7hx2n51y4412862zww00000gp/T/ipykernel_7375/4069795648.py:118: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
  scatter = ax_color_map.scatter(*zip(*values), c=colors, cmap=cmap_main, marker='.', s=100)
In [6]:
excel_file = "data_ok240115.xlsx"
sheet_name = "soybean_pre"
data = pd.read_excel(excel_file, sheet_name=sheet_name)

# Create interaction terms
data['pre_pre_interaction'] = data['pre'] * data['pre']
data['irrigation12'] = data['irrigation1'] * data['irrigation1']
data['lnfer_irrigation1_pre_interaction'] = data['lnfer'] * data['irrigation1'] * data['pre']
data['lnfer_irrigation12_pre_pre_interaction'] = data['lnfer']* data['irrigation1']  * data['irrigation1'] * data['pre'] * data['pre']

data['lnfer_pre_interaction'] = data['lnfer'] * data['pre']
data['lnfer_pre_pre_interaction'] = data['lnfer'] * data['pre'] * data['pre']

# Define predictor variables and response variable
predictor_variables1 = ['lnfer', 'pre', 'irrigation1', 'irrigation12', 'pre_pre_interaction',
                       'lnfer_irrigation1_pre_interaction', 'lnfer_irrigation12_pre_pre_interaction']

predictor_variables2 = ['lnfer', 'pre', 'pre_pre_interaction','irrigation1',
                       'lnfer_pre_interaction', 'lnfer_pre_pre_interaction']

response_variable = 'lnyield'

# Get unique group values
group_values = data['group'].unique()

# Create a figure with a grid layout
fig = plt.figure(figsize=(10, 3.5))
gs = GridSpec(1, 3, wspace=0.35, width_ratios=[4, 4, 1.2])  # Create a 1x3 grid for main plots and color map

# Define a custom gradient color map
cmap_main = plt.get_cmap('PuOr')
norm_main = Normalize(1, 5.3)  # Adjust the normalization range for the main plots color map
colors = []
values = []

for group_idx, group in enumerate(group_values):
    # Filter data for the current group
    group_data = data[data['group'] == group].copy()  # Make a copy to avoid warnings
    
    # Select predictor variables based on the group
    predictor_variables = predictor_variables1 if group == 1 else predictor_variables2
    
    # Perform regression analysis for the current group
    X = sm.add_constant(group_data[predictor_variables])
    y = group_data[response_variable]
    reg = sm.OLS(y, X).fit()
    
    x_pre = np.linspace(0, 12, 100)
    lnfer_range = np.linspace(0.3, 4.3, 100)
    irrigation1_range = np.linspace(0, 2, 100)
    
    # Update the DataFrame using .loc to avoid warnings
    group_data.loc[:, 'pre_pre_interaction'] = group_data['pre'] * group_data['pre']
    group_data.loc[:, 'lnfer_irrigation1_pre_interaction'] = group_data['lnfer'] * group_data['irrigation1'] * group_data['pre']
    group_data.loc[:, 'lnfer_irrigation12_pre_pre_interaction'] = group_data['lnfer'] * group_data['irrigation1'] * group_data['irrigation1']* group_data['pre'] * group_data['pre']
    
    if group != 1:
        group_data.loc[:, 'lnfer_pre_interaction'] = group_data['lnfer'] * group_data['pre']
        group_data.loc[:, 'lnfer_pre_pre_interaction'] = group_data['lnfer'] * group_data['pre'] * group_data['pre']
    
    for i in range(len(irrigation1_range)):
        irrigation1_value = irrigation1_range[i]
        
        # Only plot lines where irrigation is equal to 0 in group 2
        if group == 2 and irrigation1_value != 0:
            continue
        
        for j in range(len(lnfer_range)):
            lnfer_value = lnfer_range[j]
            
            if group == 1:
                y_pre = (reg.params['pre_pre_interaction'] * x_pre**2 +
                          reg.params['pre'] * x_pre +
                          reg.params['lnfer'] * lnfer_value +
                          reg.params['irrigation1'] * irrigation1_value +
                          reg.params['irrigation12'] * irrigation1_value**2 +
                          reg.params['lnfer_irrigation1_pre_interaction'] * irrigation1_value * x_pre * lnfer_value +
                          reg.params['lnfer_irrigation12_pre_pre_interaction'] * irrigation1_value**2 * lnfer_value * x_pre**2 +
                          reg.params['const'])
            else:
                y_pre = (reg.params['pre_pre_interaction'] * x_pre**2 +
                          reg.params['pre'] * x_pre +
                          reg.params['lnfer'] * lnfer_value +
                          reg.params['irrigation1'] * irrigation1_value +
                          reg.params['lnfer_pre_interaction'] * x_pre * lnfer_value +
                          reg.params['lnfer_pre_pre_interaction'] * lnfer_value * x_pre**2 +
                          reg.params['const'])
            
            # Create a subplot for each group
            ax = plt.subplot(gs[group_idx])
            
            # Calculate the color based on the normalization
            color = cmap_main(norm_main(lnfer_value + irrigation1_value))
            
            ax.plot(x_pre, y_pre, color=color, linewidth=2, alpha=0.8)

            # Collect the color and its corresponding values
            colors.append(color)
            values.append((irrigation1_value, lnfer_value))  # Store the corresponding values
           
            # Set common labels and ticks for x and y axes
            ax.set_xlabel('Precipitation (100mm)', fontsize=16, fontfamily='Arial')
            ax.set_ylabel('Ln Yield (kg/ha/yr)', fontsize=16, fontfamily='Arial')
            x_ticks = np.arange(0, 13, 4)
            y_ticks = np.arange(0, 1.9, 0.6)
            ax.set_xticks(x_ticks)
            ax.set_yticks(y_ticks)
            ax.set_yticklabels([f'{round(tick, 1)}' for tick in ax.get_yticks()], fontsize=14, fontfamily='Arial')
            ax.set_xticklabels(ax.get_xticks(), fontsize=14, fontfamily='Arial')


# Create a subplot for the color map
ax_color_map = plt.subplot(gs[2])

# Create a scatter plot for the color map subplot
scatter = ax_color_map.scatter(*zip(*values), c=colors, cmap=cmap_main, marker='.', s=100)

# Set common labels and ticks for x and y axes in the color map subplot
ax_color_map.set_xlabel('Irrigation\n($10^{3}$m$^{3}$/ha/yr)', fontfamily='Arial',fontsize=12)
ax_color_map.set_ylabel('Ln N input (kg/ha/yr)', fontsize=14, fontfamily='Arial')
ax_color_map.set_xticks(np.arange(0, 3, 1))
ax_color_map.set_xticklabels([f'{round(tick, 1)}' for tick in ax_color_map.get_xticks()], fontsize=14, fontfamily='Arial')
ax_color_map.set_yticks(np.arange(0.3, 4.4, 1.0))
ax_color_map.set_yticklabels(ax_color_map.get_yticks(), fontsize=14, fontfamily='Arial')


plt.savefig('soybean_pre.jpg', format='jpg', dpi=300, bbox_inches='tight', pad_inches=0.1)

# Show the plot
plt.show()
/var/folders/vd/0_phd7hx2n51y4412862zww00000gp/T/ipykernel_7375/2982540844.py:116: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
  scatter = ax_color_map.scatter(*zip(*values), c=colors, cmap=cmap_main, marker='.', s=100)
In [ ]:
 
In [ ]:
 
In [ ]: