Add regularization parameter to make the fitting more robust.

This commit is contained in:
exapde
2022-12-02 10:29:46 -05:00
parent 539f5b2fcb
commit e4791356c7
4 changed files with 11 additions and 5 deletions

View File

@ -6,6 +6,7 @@ path_to_test_data_set "XYZ"
fitting_weight_energy 100.0
fitting_weight_force 1.0
fitting_regularization_parameter 1e-10
error_analysis_for_training_data_set 1
error_analysis_for_test_data_set 0

View File

@ -6,6 +6,7 @@ path_to_test_data_set "../Ta/XYZ"
fitting_weight_energy 100.0
fitting_weight_force 1.0
fitting_regularization_parameter 1e-10
error_analysis_for_training_data_set 1
error_analysis_for_test_data_set 0

View File

@ -197,6 +197,7 @@ void FitPOD::read_data_file(double *fitting_weights, std::string &file_format,
if (keywd == "fraction_test_data_set") fitting_weights[8] = utils::numeric(FLERR,words[1],false,lmp);
if (keywd == "randomize_training_data_set") fitting_weights[9] = utils::numeric(FLERR,words[1],false,lmp);
if (keywd == "randomize_test_data_set") fitting_weights[10] = utils::numeric(FLERR,words[1],false,lmp);
if (keywd == "fitting_regularization_parameter") fitting_weights[11] = utils::numeric(FLERR,words[1],false,lmp);
// other settings
@ -223,6 +224,7 @@ void FitPOD::read_data_file(double *fitting_weights, std::string &file_format,
utils::logmesg(lmp, "fitting weight for energy: {}\n", fitting_weights[0]);
utils::logmesg(lmp, "fitting weight for force: {}\n", fitting_weights[1]);
utils::logmesg(lmp, "fitting weight for stress: {}\n", fitting_weights[2]);
utils::logmesg(lmp, "fitting regularization parameter: {}\n", fitting_weights[11]);
utils::logmesg(lmp, "**************** End of Data File ****************\n");
}
}
@ -1269,10 +1271,12 @@ void FitPOD::least_squares_fit(datastruct data)
for (int i = 0; i<nd*nd; i++)
desc.A[i] = desc.A[i]*maxb;
double regularizing_parameter = data.fitting_weights[11];
for (int i = 0; i<nd; i++) {
desc.c[i] = desc.b[i];
desc.A[i + nd*i] = desc.A[i + nd*i]*(1.0 + SMALL);
if (desc.A[i + nd*i] < SMALL) desc.A[i + nd*i] = SMALL;
desc.A[i + nd*i] = desc.A[i + nd*i]*(1.0 + regularizing_parameter);
if (desc.A[i + nd*i] < regularizing_parameter) desc.A[i + nd*i] = regularizing_parameter;
}
// solving the linear system A * c = b

View File

@ -62,7 +62,7 @@ public:
int randomize = 1;
double fraction = 1.0;
double fitting_weights[12] = {0.0, 0.0, 0.0, 1, 1, 0, 0, 1, 1, 1, 1, 0};
double fitting_weights[12] = {100.0, 1.0, 0.0, 1, 1, 0, 0, 1, 1, 1, 1, 1e-10};
void copydatainfo(datastruct &data) {
data.data_path = data_path;