mirror of
https://github.com/opelly27/Stockfish.git
synced 2026-05-20 16:47:37 +00:00
Move cross_entropy calculation to a separate function.
This commit is contained in:
+24
-9
@@ -197,6 +197,29 @@ namespace Learner
|
|||||||
return lambda;
|
return lambda;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename ShallowT, typename TeacherT, typename ResultT, typename LambdaT>
|
||||||
|
static auto& cross_entropy_(
|
||||||
|
ShallowT& q_,
|
||||||
|
TeacherT& p_,
|
||||||
|
ResultT& t_,
|
||||||
|
LambdaT& lambda_
|
||||||
|
)
|
||||||
|
{
|
||||||
|
using namespace Learner::Autograd::UnivariateStatic;
|
||||||
|
|
||||||
|
constexpr double epsilon = 1e-12;
|
||||||
|
|
||||||
|
static thread_local auto teacher_entropy_ = -(p_ * log(p_ + epsilon) + (1.0 - p_) * log(1.0 - p_ + epsilon));
|
||||||
|
static thread_local auto outcome_entropy_ = -(t_ * log(t_ + epsilon) + (1.0 - t_) * log(1.0 - t_ + epsilon));
|
||||||
|
static thread_local auto teacher_loss_ = -(p_ * log(q_) + (1.0 - p_) * log(1.0 - q_));
|
||||||
|
static thread_local auto outcome_loss_ = -(t_ * log(q_) + (1.0 - t_) * log(1.0 - q_));
|
||||||
|
static thread_local auto result_ = lambda_ * teacher_loss_ + (1.0 - lambda_) * outcome_loss_;
|
||||||
|
static thread_local auto entropy_ = lambda_ * teacher_entropy_ + (1.0 - lambda_) * outcome_entropy_;
|
||||||
|
static thread_local auto loss_ = result_ - entropy_;
|
||||||
|
|
||||||
|
return loss_;
|
||||||
|
}
|
||||||
|
|
||||||
static ValueWithGrad<double> get_loss(Value shallow, Value teacher_signal, int result, int ply)
|
static ValueWithGrad<double> get_loss(Value shallow, Value teacher_signal, int result, int ply)
|
||||||
{
|
{
|
||||||
using namespace Learner::Autograd::UnivariateStatic;
|
using namespace Learner::Autograd::UnivariateStatic;
|
||||||
@@ -215,19 +238,11 @@ namespace Learner
|
|||||||
auto loss_ = pow(q_ - p_, 2.0) * (1.0 / (2400.0 * 2.0 * 600.0));
|
auto loss_ = pow(q_ - p_, 2.0) * (1.0 / (2400.0 * 2.0 * 600.0));
|
||||||
*/
|
*/
|
||||||
|
|
||||||
constexpr double epsilon = 1e-12;
|
|
||||||
|
|
||||||
static thread_local auto q_ = sigmoid(VariableParameter<double, 0>{} * ConstantParameter<double, 4>{});
|
static thread_local auto q_ = sigmoid(VariableParameter<double, 0>{} * ConstantParameter<double, 4>{});
|
||||||
static thread_local auto p_ = sigmoid(ConstantParameter<double, 1>{} * ConstantParameter<double, 4>{});
|
static thread_local auto p_ = sigmoid(ConstantParameter<double, 1>{} * ConstantParameter<double, 4>{});
|
||||||
static thread_local auto t_ = (ConstantParameter<double, 2>{} + 1.0) * 0.5;
|
static thread_local auto t_ = (ConstantParameter<double, 2>{} + 1.0) * 0.5;
|
||||||
static thread_local auto lambda_ = ConstantParameter<double, 3>{};
|
static thread_local auto lambda_ = ConstantParameter<double, 3>{};
|
||||||
static thread_local auto teacher_entropy_ = -(p_ * log(p_ + epsilon) + (1.0 - p_) * log(1.0 - p_ + epsilon));
|
static thread_local auto loss_ = cross_entropy_(q_, p_, t_, lambda_);
|
||||||
static thread_local auto outcome_entropy_ = -(t_ * log(t_ + epsilon) + (1.0 - t_) * log(1.0 - t_ + epsilon));
|
|
||||||
static thread_local auto teacher_loss_ = -(p_ * log(q_) + (1.0 - p_) * log(1.0 - q_));
|
|
||||||
static thread_local auto outcome_loss_ = -(t_ * log(q_) + (1.0 - t_) * log(1.0 - q_));
|
|
||||||
static thread_local auto result_ = lambda_ * teacher_loss_ + (1.0 - lambda_) * outcome_loss_;
|
|
||||||
static thread_local auto entropy_ = lambda_ * teacher_entropy_ + (1.0 - lambda_) * outcome_entropy_;
|
|
||||||
static thread_local auto loss_ = result_ - entropy_;
|
|
||||||
|
|
||||||
auto args = std::tuple(
|
auto args = std::tuple(
|
||||||
(double)shallow,
|
(double)shallow,
|
||||||
|
|||||||
Reference in New Issue
Block a user