mirror of
https://github.com/opelly27/Stockfish.git
synced 2026-05-20 13:17:44 +00:00
Cross entropy loss.
This commit is contained in:
+34
-2
@@ -282,6 +282,38 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
return Product(Constant(lhs), std::move(rhs));
|
return Product(Constant(lhs), std::move(rhs));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||||
|
struct Negation : Evaluable<Negation<ArgT, T>>
|
||||||
|
{
|
||||||
|
using ValueType = T;
|
||||||
|
|
||||||
|
explicit Negation(ArgT x) :
|
||||||
|
m_x(std::move(x))
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... ArgsTs>
|
||||||
|
T value(const std::tuple<ArgsTs...>& args) const
|
||||||
|
{
|
||||||
|
return -m_x.value(args);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... ArgsTs>
|
||||||
|
T grad(const std::tuple<ArgsTs...>& args) const
|
||||||
|
{
|
||||||
|
return -m_x.grad(args);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
ArgT m_x;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||||
|
auto operator-(ArgT x)
|
||||||
|
{
|
||||||
|
return Negation(std::move(x));
|
||||||
|
}
|
||||||
|
|
||||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||||
struct Sigmoid : Evaluable<Sigmoid<ArgT, T>>
|
struct Sigmoid : Evaluable<Sigmoid<ArgT, T>>
|
||||||
{
|
{
|
||||||
@@ -318,7 +350,7 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename ArgT>
|
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||||
auto sigmoid(ArgT x)
|
auto sigmoid(ArgT x)
|
||||||
{
|
{
|
||||||
return Sigmoid(std::move(x));
|
return Sigmoid(std::move(x));
|
||||||
@@ -394,7 +426,7 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename ArgT>
|
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||||
auto log(ArgT x)
|
auto log(ArgT x)
|
||||||
{
|
{
|
||||||
return Log(std::move(x));
|
return Log(std::move(x));
|
||||||
|
|||||||
@@ -200,11 +200,14 @@ namespace Learner
|
|||||||
// Builds the training loss expression and evaluates it for one sample.
//
// The active loss is a lambda-weighted mix of two cross-entropy terms with
// the matching entropies subtracted:
//   lambda * (CE(p, q) - H(p)) + (1 - lambda) * (CE(t, q) - H(t))
// where
//   q = sigmoid(shallow        * winning_probability_coefficient)  (parameter 0, the variable)
//   p = sigmoid(teacher_signal * winning_probability_coefficient)  (parameter 1)
//   t = (result + 1) / 2                                           (parameter 2)
//   lambda = calculate_lambda(teacher_signal)                      (parameter 3)
//
// shallow        - score passed as the variable parameter.
// teacher_signal - score passed as a constant parameter; also feeds calculate_lambda().
// result         - outcome; the (result + 1) / 2 mapping puts -1/0/+1 onto 0/0.5/1.
// ply            - not used by this implementation.
//
// Returns whatever loss_.eval(args) produces for these arguments
// (presumably the loss and its derivative with respect to parameter 0 --
// TODO confirm against Evaluable::eval).
static ValueWithGrad<double> get_loss(Value shallow, Value teacher_signal, int result, int ply)
{
    using namespace Learner::Autograd::UnivariateStatic;

    // Alternative formulation kept for reference: squared error in
    // winning-probability space.
    /*
    auto q_ = sigmoid(VariableParameter<double, 0>{} * winning_probability_coefficient);
    auto p_ = sigmoid(ConstantParameter<double, 1>{} * winning_probability_coefficient);
    auto t_ = (ConstantParameter<double, 2>{} + 1.0) * 0.5;
    auto lambda_ = ConstantParameter<double, 3>{};
    auto loss_ = pow(lambda_ * (q_ - p_) + (1.0 - lambda_) * (q_ - t_), 2.0);
    */

    // Alternative formulation kept for reference: squared error on raw scores.
    /*
    auto q_ = VariableParameter<double, 0>{};
    ... (one line of this commented-out block is not visible in this diff excerpt) ...
    auto loss_ = pow(q_ - p_, 2.0) * (1.0 / (2400.0 * 2.0 * 600.0));
    */

    // Added inside every logarithm so that an argument of exactly 0 yields a
    // large finite value instead of -inf (and a NaN gradient).
    const double epsilon = 1e-12;

    auto q_ = sigmoid(VariableParameter<double, 0>{} * winning_probability_coefficient);
    auto p_ = sigmoid(ConstantParameter<double, 1>{} * winning_probability_coefficient);
    auto t_ = (ConstantParameter<double, 2>{} + 1.0) * 0.5;
    auto lambda_ = ConstantParameter<double, 3>{};

    auto teacher_entropy_ = -(p_ * log(p_ + epsilon) + (1.0 - p_) * log(1.0 - p_ + epsilon));
    auto outcome_entropy_ = -(t_ * log(t_ + epsilon) + (1.0 - t_) * log(1.0 - t_ + epsilon));

    // The q_ logarithms use the same epsilon as the entropy terms. Without it
    // a q_ that rounds to exactly 0 or 1 makes the loss infinite/NaN, and the
    // difference "cross entropy - entropy" is not exactly zero when q_ == p_.
    auto teacher_loss_ = -(p_ * log(q_ + epsilon) + (1.0 - p_) * log(1.0 - q_ + epsilon));
    auto outcome_loss_ = -(t_ * log(q_ + epsilon) + (1.0 - t_) * log(1.0 - q_ + epsilon));

    auto result_ = lambda_ * teacher_loss_ + (1.0 - lambda_) * outcome_loss_;
    auto entropy_ = lambda_ * teacher_entropy_ + (1.0 - lambda_) * outcome_entropy_;

    // Subtracting the entropy leaves the divergence between target and
    // prediction, whose minimum is zero.
    auto loss_ = result_ - entropy_;

    // Tuple positions match the parameter indices used in the expression above.
    auto args = std::tuple(
        static_cast<double>(shallow),
        static_cast<double>(teacher_signal),
        static_cast<double>(result),
        calculate_lambda(teacher_signal));

    return loss_.eval(args);
}
|
||||||
|
|||||||
Reference in New Issue
Block a user