mirror of
https://github.com/opelly27/Stockfish.git
synced 2026-05-20 08:37:44 +00:00
Detect constant expressions in autograd and return 0 grad early.
This commit is contained in:
+37
-7
@@ -120,6 +120,9 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
|
|
||||||
template <typename T, typename Tuple>
|
template <typename T, typename Tuple>
|
||||||
constexpr bool TupleContainsV = TupleContains<T, Tuple>::value;
|
constexpr bool TupleContainsV = TupleContains<T, Tuple>::value;
|
||||||
|
|
||||||
|
template <typename... Ts>
|
||||||
|
constexpr bool AreAllConstantV = (std::remove_reference_t<Ts>::is_constant && ...);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename ChildT>
|
template <typename T, typename ChildT>
|
||||||
@@ -167,16 +170,23 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
typename SFINAE = std::enable_if_t<Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
|
typename SFINAE = std::enable_if_t<Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
|
||||||
[[nodiscard]] auto grad(const std::tuple<ArgsTs...>& args) const
|
[[nodiscard]] auto grad(const std::tuple<ArgsTs...>& args) const
|
||||||
{
|
{
|
||||||
const ChildT* this_ = static_cast<const ChildT*>(this);
|
if constexpr (ChildT::is_constant)
|
||||||
|
|
||||||
const auto call_id = std::get<Detail::CallId>(args);
|
|
||||||
if (!grad_cache.has_value() || grad_cache_call_id != call_id)
|
|
||||||
{
|
{
|
||||||
grad_cache_call_id = call_id;
|
return T(0.0);
|
||||||
grad_cache = this_->calculate_grad(args);
|
|
||||||
}
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
const ChildT* this_ = static_cast<const ChildT*>(this);
|
||||||
|
|
||||||
return *grad_cache;
|
const auto call_id = std::get<Detail::CallId>(args);
|
||||||
|
if (!grad_cache.has_value() || grad_cache_call_id != call_id)
|
||||||
|
{
|
||||||
|
grad_cache_call_id = call_id;
|
||||||
|
grad_cache = this_->calculate_grad(args);
|
||||||
|
}
|
||||||
|
|
||||||
|
return *grad_cache;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename... ArgsTs,
|
template <typename... ArgsTs,
|
||||||
@@ -199,6 +209,8 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
struct VariableParameter : Evaluable<T, VariableParameter<T, I>>
|
struct VariableParameter : Evaluable<T, VariableParameter<T, I>>
|
||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
|
static constexpr bool is_constant = false;
|
||||||
|
|
||||||
constexpr VariableParameter()
|
constexpr VariableParameter()
|
||||||
{
|
{
|
||||||
@@ -222,6 +234,8 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
|
static constexpr bool is_constant = true;
|
||||||
|
|
||||||
constexpr ConstantParameter()
|
constexpr ConstantParameter()
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
@@ -244,6 +258,8 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
|
static constexpr bool is_constant = true;
|
||||||
|
|
||||||
constexpr Constant(T x) :
|
constexpr Constant(T x) :
|
||||||
m_x(std::move(x))
|
m_x(std::move(x))
|
||||||
{
|
{
|
||||||
@@ -270,6 +286,8 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
|
static constexpr bool is_constant = Detail::AreAllConstantV<LhsT, RhsT>;
|
||||||
|
|
||||||
constexpr Sum(LhsT&& lhs, RhsT&& rhs) :
|
constexpr Sum(LhsT&& lhs, RhsT&& rhs) :
|
||||||
m_lhs(std::forward<LhsT>(lhs)),
|
m_lhs(std::forward<LhsT>(lhs)),
|
||||||
m_rhs(std::forward<RhsT>(rhs))
|
m_rhs(std::forward<RhsT>(rhs))
|
||||||
@@ -316,6 +334,8 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
|
static constexpr bool is_constant = Detail::AreAllConstantV<LhsT, RhsT>;
|
||||||
|
|
||||||
constexpr Difference(LhsT&& lhs, RhsT&& rhs) :
|
constexpr Difference(LhsT&& lhs, RhsT&& rhs) :
|
||||||
m_lhs(std::forward<LhsT>(lhs)),
|
m_lhs(std::forward<LhsT>(lhs)),
|
||||||
m_rhs(std::forward<RhsT>(rhs))
|
m_rhs(std::forward<RhsT>(rhs))
|
||||||
@@ -362,6 +382,8 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
|
static constexpr bool is_constant = Detail::AreAllConstantV<LhsT, RhsT>;
|
||||||
|
|
||||||
constexpr Product(LhsT&& lhs, RhsT&& rhs) :
|
constexpr Product(LhsT&& lhs, RhsT&& rhs) :
|
||||||
m_lhs(std::forward<LhsT>(lhs)),
|
m_lhs(std::forward<LhsT>(lhs)),
|
||||||
m_rhs(std::forward<RhsT>(rhs))
|
m_rhs(std::forward<RhsT>(rhs))
|
||||||
@@ -408,6 +430,8 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
|
static constexpr bool is_constant = Detail::AreAllConstantV<ArgT>;
|
||||||
|
|
||||||
constexpr explicit Negation(ArgT&& x) :
|
constexpr explicit Negation(ArgT&& x) :
|
||||||
m_x(std::forward<ArgT>(x))
|
m_x(std::forward<ArgT>(x))
|
||||||
{
|
{
|
||||||
@@ -440,6 +464,8 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
|
static constexpr bool is_constant = Detail::AreAllConstantV<ArgT>;
|
||||||
|
|
||||||
constexpr explicit Sigmoid(ArgT&& x) :
|
constexpr explicit Sigmoid(ArgT&& x) :
|
||||||
m_x(std::forward<ArgT>(x))
|
m_x(std::forward<ArgT>(x))
|
||||||
{
|
{
|
||||||
@@ -482,6 +508,8 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
|
static constexpr bool is_constant = Detail::AreAllConstantV<ArgT>;
|
||||||
|
|
||||||
constexpr explicit Pow(ArgT&& x, Id<T> exponent) :
|
constexpr explicit Pow(ArgT&& x, Id<T> exponent) :
|
||||||
m_x(std::forward<ArgT>(x)),
|
m_x(std::forward<ArgT>(x)),
|
||||||
m_exponent(std::move(exponent))
|
m_exponent(std::move(exponent))
|
||||||
@@ -516,6 +544,8 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
|
static constexpr bool is_constant = Detail::AreAllConstantV<ArgT>;
|
||||||
|
|
||||||
constexpr explicit Log(ArgT&& x) :
|
constexpr explicit Log(ArgT&& x) :
|
||||||
m_x(std::forward<ArgT>(x))
|
m_x(std::forward<ArgT>(x))
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user