import os
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
def plot_cost(j_obs, j_reg, norm_noise, label_reg, ls_reg, save_path=None, file_name=None, file_format='pdf',
              figsize=(20, 15), fontsize=25):
    """Plot the total cost, the observation cost and each regularization cost per iteration.

    The total cost ``j`` is ``j_obs`` plus the element-wise sum of every array in ``j_reg``.
    All curves are drawn on a logarithmic y-axis, together with a dashed horizontal
    reference line at ``norm_noise``.

    Args:
        j_obs (array): observation cost, one value per iteration
        j_reg (dict of arrays): cost of the regularization terms; each array must be
            broadcast-compatible with ``j_obs`` (normally the same length)
        norm_noise (float): noise level, drawn as a dashed horizontal line
        label_reg (dict of strings): labels of the regularization terms (same keys as
            ``j_reg``), inserted as LaTeX subscripts in the legend
        ls_reg (dict of strings): matplotlib format strings (line styles) of the
            regularization terms (same keys as ``j_reg``)
        save_path (str, optional): Save directory, created if missing. If None, the
            figure is shown instead of saved. Defaults to None.
        file_name (str, optional): plot file name without extension. Defaults to None,
            in which case 'cost' is used.
        file_format (str, optional): file format, also used as the file extension.
            Defaults to 'pdf'.
        figsize (tuple, optional): size of the figure. Defaults to (20, 15).
        fontsize (int, optional): font size of axis labels and legend; tick labels use
            ``fontsize - 5``. Defaults to 25.
    """
    # Accumulate on a float copy: an integer-valued ``j_obs`` would make the in-place
    # ``+=`` fail with a casting error, and the caller's array must not be mutated.
    j_total = np.array(j_obs, dtype=float)
    for j_term in j_reg.values():
        j_total += np.asarray(j_term, dtype=float)

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(j_total, 'c', label='$j$')
    ax.plot(j_obs, 'b', label='$j_{obs}$')
    for key, j_term in j_reg.items():
        lbl = label_reg[key]
        ax.plot(j_term, ls_reg[key], label=rf'$\alpha_{{{lbl}}}j_{{{lbl}}}$')
    # Constant noise-level reference line spanning the same iterations as j_obs.
    ax.plot(np.full(len(j_obs), norm_noise), 'k--', label=r'$||\epsilon||_{L^2}^2$')
    ax.set_xlabel(r'iteration', fontsize=fontsize)
    ax.set_ylabel(r'$j$', fontsize=fontsize)
    ax.tick_params(labelsize=fontsize - 5)
    ax.set_yscale('log')
    ax.legend(fontsize=fontsize)

    if save_path is None:
        plt.show()
        return

    save_dir = Path(save_path)
    # exist_ok avoids the check-then-create race of isdir() followed by makedirs().
    save_dir.mkdir(parents=True, exist_ok=True)
    if file_name is None:
        file_name = 'cost'
    fig.savefig(save_dir / f"{file_name}.{file_format}", format=file_format)
    # Close only this figure so unrelated figures the caller has open are preserved.
    plt.close(fig)
def plot_cost_obs_comparison(case_dict, save_path=None, file_name=None, file_format='pdf', figsize=(20, 15),
                             fontsize=25):
    """Plot the observation cost of several cases on a common log-scale axis.

    Args:
        case_dict (dict): The different cases. Each value is a mapping that provides
            ``'j_obs'`` (observation cost per iteration), ``'ls'`` (matplotlib format
            string) and ``'label'`` (legend entry). The dict keys themselves are not used.
        save_path (str, optional): Save directory, created if missing. If None, the
            figure is shown instead of saved. Defaults to None.
        file_name (str, optional): File name without extension. Defaults to None,
            in which case 'cost_obs_comparison' is used.
        file_format (str, optional): file format, also used as the file extension.
            Defaults to 'pdf'.
        figsize (tuple, optional): figure size. Defaults to (20, 15).
        fontsize (int, optional): font size of axis labels and legend; tick labels use
            ``fontsize - 5``. Defaults to 25.
    """
    fig, ax = plt.subplots(figsize=figsize)
    for case in case_dict.values():
        ax.plot(case['j_obs'], case['ls'], label=case['label'])
    ax.set_xlabel(r'iteration', fontsize=fontsize)
    ax.set_ylabel(r'$j_{obs}$', fontsize=fontsize)
    ax.tick_params(labelsize=fontsize - 5)
    ax.set_yscale('log')
    ax.legend(fontsize=fontsize)

    if save_path is None:
        plt.show()
        return

    save_dir = Path(save_path)
    # exist_ok avoids the check-then-create race of isdir() followed by makedirs().
    save_dir.mkdir(parents=True, exist_ok=True)
    if file_name is None:
        file_name = 'cost_obs_comparison'
    fig.savefig(save_dir / f"{file_name}.{file_format}", format=file_format)
    # Close only this figure so unrelated figures the caller has open are preserved.
    plt.close(fig)