mirror of
https://github.com/opelly27/Stockfish.git
synced 2026-05-20 16:47:37 +00:00
Make all grad related functions in learn static. Pass calc_grad as a parameter.
This commit is contained in:
+17
-23
@@ -185,7 +185,7 @@ namespace Learner
|
|||||||
}
|
}
|
||||||
|
|
||||||
// A function that converts the evaluation value to the winning rate [0,1]
|
// A function that converts the evaluation value to the winning rate [0,1]
|
||||||
double winning_percentage(double value)
|
static double winning_percentage(double value)
|
||||||
{
|
{
|
||||||
// 1/(1+10^(-Eval/4))
|
// 1/(1+10^(-Eval/4))
|
||||||
// = 1/(1+e^(-Eval/4*ln(10)))
|
// = 1/(1+e^(-Eval/4*ln(10)))
|
||||||
@@ -194,7 +194,7 @@ namespace Learner
|
|||||||
}
|
}
|
||||||
|
|
||||||
// A function that converts the evaluation value to the winning rate [0,1]
|
// A function that converts the evaluation value to the winning rate [0,1]
|
||||||
double winning_percentage_wdl(double value, int ply)
|
static double winning_percentage_wdl(double value, int ply)
|
||||||
{
|
{
|
||||||
constexpr double wdl_total = 1000.0;
|
constexpr double wdl_total = 1000.0;
|
||||||
constexpr double draw_score = 0.5;
|
constexpr double draw_score = 0.5;
|
||||||
@@ -207,7 +207,7 @@ namespace Learner
|
|||||||
}
|
}
|
||||||
|
|
||||||
// A function that converts the evaluation value to the winning rate [0,1]
|
// A function that converts the evaluation value to the winning rate [0,1]
|
||||||
double winning_percentage(double value, int ply)
|
static double winning_percentage(double value, int ply)
|
||||||
{
|
{
|
||||||
if (use_wdl)
|
if (use_wdl)
|
||||||
{
|
{
|
||||||
@@ -219,7 +219,7 @@ namespace Learner
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
double calc_cross_entropy_of_winning_percentage(
|
static double calc_cross_entropy_of_winning_percentage(
|
||||||
double deep_win_rate,
|
double deep_win_rate,
|
||||||
double shallow_eval,
|
double shallow_eval,
|
||||||
int ply)
|
int ply)
|
||||||
@@ -229,7 +229,7 @@ namespace Learner
|
|||||||
return -p * std::log(q) - (1.0 - p) * std::log(1.0 - q);
|
return -p * std::log(q) - (1.0 - p) * std::log(1.0 - q);
|
||||||
}
|
}
|
||||||
|
|
||||||
double calc_d_cross_entropy_of_winning_percentage(
|
static double calc_d_cross_entropy_of_winning_percentage(
|
||||||
double deep_win_rate,
|
double deep_win_rate,
|
||||||
double shallow_eval,
|
double shallow_eval,
|
||||||
int ply)
|
int ply)
|
||||||
@@ -248,7 +248,7 @@ namespace Learner
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Training Formula · Issue #71 · nodchip/Stockfish https://github.com/nodchip/Stockfish/issues/71
|
// Training Formula · Issue #71 · nodchip/Stockfish https://github.com/nodchip/Stockfish/issues/71
|
||||||
double get_scaled_signal(double signal)
|
static double get_scaled_signal(double signal)
|
||||||
{
|
{
|
||||||
double scaled_signal = signal;
|
double scaled_signal = signal;
|
||||||
|
|
||||||
@@ -266,13 +266,13 @@ namespace Learner
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Teacher winning probability.
|
// Teacher winning probability.
|
||||||
double calculate_p(double teacher_signal, int ply)
|
static double calculate_p(double teacher_signal, int ply)
|
||||||
{
|
{
|
||||||
const double scaled_teacher_signal = get_scaled_signal(teacher_signal);
|
const double scaled_teacher_signal = get_scaled_signal(teacher_signal);
|
||||||
return winning_percentage(scaled_teacher_signal, ply);
|
return winning_percentage(scaled_teacher_signal, ply);
|
||||||
}
|
}
|
||||||
|
|
||||||
double calculate_lambda(double teacher_signal)
|
static double calculate_lambda(double teacher_signal)
|
||||||
{
|
{
|
||||||
// If the evaluation value in deep search exceeds elmo_lambda_limit
|
// If the evaluation value in deep search exceeds elmo_lambda_limit
|
||||||
// then apply elmo_lambda_high instead of elmo_lambda_low.
|
// then apply elmo_lambda_high instead of elmo_lambda_low.
|
||||||
@@ -284,7 +284,7 @@ namespace Learner
|
|||||||
return lambda;
|
return lambda;
|
||||||
}
|
}
|
||||||
|
|
||||||
double calculate_t(int game_result)
|
static double calculate_t(int game_result)
|
||||||
{
|
{
|
||||||
// Use 1 as the correction term if you win,
|
// Use 1 as the correction term if you win,
|
||||||
// 0 if you lose, and 0.5 if you draw.
|
// 0 if you lose, and 0.5 if you draw.
|
||||||
@@ -294,20 +294,20 @@ namespace Learner
|
|||||||
return t;
|
return t;
|
||||||
}
|
}
|
||||||
|
|
||||||
double calc_grad(Value teacher_signal, Value shallow, const PackedSfenValue& psv)
|
static double calc_grad(Value shallow, Value teacher_signal, int result, int ply)
|
||||||
{
|
{
|
||||||
// elmo (WCSC27) method
|
// elmo (WCSC27) method
|
||||||
// Correct with the actual game wins and losses.
|
// Correct with the actual game wins and losses.
|
||||||
const double q = winning_percentage(shallow, psv.gamePly);
|
const double q = winning_percentage(shallow, ply);
|
||||||
const double p = calculate_p(teacher_signal, psv.gamePly);
|
const double p = calculate_p(teacher_signal, ply);
|
||||||
const double t = calculate_t(psv.game_result);
|
const double t = calculate_t(result);
|
||||||
const double lambda = calculate_lambda(teacher_signal);
|
const double lambda = calculate_lambda(teacher_signal);
|
||||||
|
|
||||||
double grad;
|
double grad;
|
||||||
if (use_wdl)
|
if (use_wdl)
|
||||||
{
|
{
|
||||||
const double dce_p = calc_d_cross_entropy_of_winning_percentage(p, shallow, psv.gamePly);
|
const double dce_p = calc_d_cross_entropy_of_winning_percentage(p, shallow, ply);
|
||||||
const double dce_t = calc_d_cross_entropy_of_winning_percentage(t, shallow, psv.gamePly);
|
const double dce_t = calc_d_cross_entropy_of_winning_percentage(t, shallow, ply);
|
||||||
grad = lambda * dce_p + (1.0 - lambda) * dce_t;
|
grad = lambda * dce_p + (1.0 - lambda) * dce_t;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
@@ -324,7 +324,7 @@ namespace Learner
|
|||||||
// The individual cross entropy of the win/loss term and win
|
// The individual cross entropy of the win/loss term and win
|
||||||
// rate term of the elmo expression is returned
|
// rate term of the elmo expression is returned
|
||||||
// to the arguments cross_entropy_eval and cross_entropy_win.
|
// to the arguments cross_entropy_eval and cross_entropy_win.
|
||||||
Loss calc_cross_entropy(
|
static Loss calc_cross_entropy(
|
||||||
Value teacher_signal,
|
Value teacher_signal,
|
||||||
Value shallow,
|
Value shallow,
|
||||||
const PackedSfenValue& psv)
|
const PackedSfenValue& psv)
|
||||||
@@ -360,12 +360,6 @@ namespace Learner
|
|||||||
return loss;
|
return loss;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Other objective functions may be considered in the future...
|
|
||||||
double calc_grad(Value shallow, const PackedSfenValue& psv)
|
|
||||||
{
|
|
||||||
return calc_grad((Value)psv.score, shallow, psv);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Class to generate sfen with multiple threads
|
// Class to generate sfen with multiple threads
|
||||||
struct LearnerThink
|
struct LearnerThink
|
||||||
{
|
{
|
||||||
@@ -703,7 +697,7 @@ namespace Learner
|
|||||||
// should be no real issues happening since
|
// should be no real issues happening since
|
||||||
// the read/write phases are isolated.
|
// the read/write phases are isolated.
|
||||||
atomic_thread_fence(memory_order_seq_cst);
|
atomic_thread_fence(memory_order_seq_cst);
|
||||||
Eval::NNUE::update_parameters(epoch, params.verbose);
|
Eval::NNUE::update_parameters(epoch, params.verbose, calc_grad);
|
||||||
atomic_thread_fence(memory_order_seq_cst);
|
atomic_thread_fence(memory_order_seq_cst);
|
||||||
|
|
||||||
if (++save_count * params.mini_batch_size >= params.eval_save_interval)
|
if (++save_count * params.mini_batch_size >= params.eval_save_interval)
|
||||||
|
|||||||
+2
-2
@@ -64,10 +64,10 @@ namespace Learner
|
|||||||
// RMSE calculation runs in a single thread and takes some time, so reducing how often it is output is effective.
|
// RMSE calculation runs in a single thread and takes some time, so reducing how often it is output is effective.
|
||||||
constexpr std::size_t LEARN_RMSE_OUTPUT_INTERVAL = 1;
|
constexpr std::size_t LEARN_RMSE_OUTPUT_INTERVAL = 1;
|
||||||
|
|
||||||
double calc_grad(Value shallow, const PackedSfenValue& psv);
|
|
||||||
|
|
||||||
// Learning from the generated game record
|
// Learning from the generated game record
|
||||||
void learn(std::istringstream& is);
|
void learn(std::istringstream& is);
|
||||||
|
|
||||||
|
using CalcGradFunc = double(Value, Value, int, int);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // ifndef _LEARN_H_
|
#endif // ifndef _LEARN_H_
|
||||||
|
|||||||
@@ -18,8 +18,6 @@
|
|||||||
#include "misc.h"
|
#include "misc.h"
|
||||||
#include "thread_win32_osx.h"
|
#include "thread_win32_osx.h"
|
||||||
|
|
||||||
#include "learn/learn.h"
|
|
||||||
|
|
||||||
// Learning rate scale
|
// Learning rate scale
|
||||||
double global_learning_rate;
|
double global_learning_rate;
|
||||||
|
|
||||||
@@ -183,7 +181,7 @@ namespace Eval::NNUE {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// update the evaluation function parameters
|
// update the evaluation function parameters
|
||||||
void update_parameters(uint64_t epoch, bool verbose) {
|
void update_parameters(uint64_t epoch, bool verbose, Learner::CalcGradFunc calc_grad) {
|
||||||
assert(batch_size > 0);
|
assert(batch_size > 0);
|
||||||
|
|
||||||
const auto learning_rate = static_cast<LearnFloatType>(
|
const auto learning_rate = static_cast<LearnFloatType>(
|
||||||
@@ -210,7 +208,8 @@ namespace Eval::NNUE {
|
|||||||
batch[b].sign * network_output[b] * kPonanzaConstant));
|
batch[b].sign * network_output[b] * kPonanzaConstant));
|
||||||
const auto discrete = batch[b].sign * batch[b].discrete_nn_eval;
|
const auto discrete = batch[b].sign * batch[b].discrete_nn_eval;
|
||||||
const auto& psv = batch[b].psv;
|
const auto& psv = batch[b].psv;
|
||||||
const double gradient = batch[b].sign * Learner::calc_grad(shallow, psv);
|
const double gradient =
|
||||||
|
batch[b].sign * calc_grad(shallow, (Value)psv.score, psv.game_result, psv.gamePly);
|
||||||
gradients[b] = static_cast<LearnFloatType>(gradient * batch[b].weight);
|
gradients[b] = static_cast<LearnFloatType>(gradient * batch[b].weight);
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ namespace Eval::NNUE {
|
|||||||
double weight);
|
double weight);
|
||||||
|
|
||||||
// update the evaluation function parameters
|
// update the evaluation function parameters
|
||||||
void update_parameters(uint64_t epoch, bool verbose);
|
void update_parameters(uint64_t epoch, bool verbose, Learner::CalcGradFunc calc_grad);
|
||||||
|
|
||||||
// Check if there are any problems with learning
|
// Check if there are any problems with learning
|
||||||
void check_health();
|
void check_health();
|
||||||
|
|||||||
Reference in New Issue
Block a user