Use winning_percentage_wdl in learn

This commit is contained in:
tttak
2020-08-24 22:56:08 +09:00
parent 7ee8a2bbb7
commit 4ce30d9522
3 changed files with 80 additions and 21 deletions
+56 -5
View File
@@ -133,6 +133,8 @@ double dest_score_max_value = 1.0;
// probabilities in the trainer. Sometimes we want to use the winning probabilities in the training
// data directly. In those cases, we set false to this variable.
bool convert_teacher_signal_to_winning_probability = true;
// Using WDL with win rate model instead of sigmoid
bool use_wdl = false;
// -----------------------------------
// write phase file
@@ -1162,6 +1164,45 @@ double winning_percentage(double value)
// = sigmoid(Eval/4*ln(10))
return sigmoid(value * winning_probability_coefficient);
}
// A function that converts the evaluation value to the winning rate [0,1]
double winning_percentage_wdl(double value, int ply)
{
double wdl_w = UCI::win_rate_model_double( value, ply);
double wdl_l = UCI::win_rate_model_double(-value, ply);
double wdl_d = 1000.0 - wdl_w - wdl_l;
return (wdl_w + wdl_d / 2.0) / 1000.0;
}
// A function that converts the evaluation value to the winning rate [0,1]
double winning_percentage(double value, int ply)
{
if (use_wdl) {
return winning_percentage_wdl(value, ply);
}
else {
return winning_percentage(value);
}
}
double calc_cross_entropy_of_winning_percentage(double deep_win_rate, double shallow_eval, int ply)
{
double p = deep_win_rate;
double q = winning_percentage(shallow_eval, ply);
return -p * std::log(q) - (1 - p) * std::log(1 - q);
}
double calc_d_cross_entropy_of_winning_percentage(double deep_win_rate, double shallow_eval, int ply)
{
constexpr double epsilon = 0.000001;
double y1 = calc_cross_entropy_of_winning_percentage(deep_win_rate, shallow_eval , ply);
double y2 = calc_cross_entropy_of_winning_percentage(deep_win_rate, shallow_eval + epsilon, ply);
// Divide by the winning_probability_coefficient to match scale with the sigmoidal win rate
return ((y2 - y1) / epsilon) / winning_probability_coefficient;
}
double dsigmoid(double x)
{
// Sigmoid function
@@ -1263,11 +1304,11 @@ double calc_grad(Value teacher_signal, Value shallow , const PackedSfenValue& ps
// Scale to [dest_score_min_value, dest_score_max_value].
scaled_teacher_signal = scaled_teacher_signal * (dest_score_max_value - dest_score_min_value) + dest_score_min_value;
const double q = winning_percentage(shallow);
const double q = winning_percentage(shallow, psv.gamePly);
// Teacher winning probability.
double p = scaled_teacher_signal;
if (convert_teacher_signal_to_winning_probability) {
p = winning_percentage(scaled_teacher_signal);
p = winning_percentage(scaled_teacher_signal, psv.gamePly);
}
// Use 1 as the correction term if the expected win rate is 1, 0 if you lose, and 0.5 if you draw.
@@ -1277,9 +1318,17 @@ double calc_grad(Value teacher_signal, Value shallow , const PackedSfenValue& ps
// If the evaluation value in deep search exceeds ELMO_LAMBDA_LIMIT, apply ELMO_LAMBDA2 instead of ELMO_LAMBDA.
const double lambda = (abs(teacher_signal) >= ELMO_LAMBDA_LIMIT) ? ELMO_LAMBDA2 : ELMO_LAMBDA;
// Use the actual win rate as a correction term.
// This is the idea of elmo (WCSC27), modern O-parts.
const double grad = lambda * (q - p) + (1.0 - lambda) * (q - t);
double grad;
if (use_wdl) {
double dce_p = calc_d_cross_entropy_of_winning_percentage(p, shallow, psv.gamePly);
double dce_t = calc_d_cross_entropy_of_winning_percentage(t, shallow, psv.gamePly);
grad = lambda * dce_p + (1.0 - lambda) * dce_t;
}
else {
// Use the actual win rate as a correction term.
// This is the idea of elmo (WCSC27), modern O-parts.
grad = lambda * (q - p) + (1.0 - lambda) * (q - t);
}
return grad;
}
@@ -3168,6 +3217,8 @@ void learn(Position&, istringstream& is)
else if (option == "winning_probability_coefficient") is >> winning_probability_coefficient;
// Discount rate
else if (option == "discount_rate") is >> discount_rate;
// Using WDL with win rate model instead of sigmoid
else if (option == "use_wdl") is >> use_wdl;
// No learning of KK/KKP/KPP/KPPP.
else if (option == "freeze_kk") is >> freeze[0];