Detect constant expressions in autograd and return 0 grad early.

This commit is contained in:
Tomasz Sobczyk
2020-11-30 16:50:51 +01:00
committed by nodchip
parent e975889132
commit cbd973fdaa
+37 -7
View File
@@ -120,6 +120,9 @@ namespace Learner::Autograd::UnivariateStatic
template <typename T, typename Tuple> template <typename T, typename Tuple>
constexpr bool TupleContainsV = TupleContains<T, Tuple>::value; constexpr bool TupleContainsV = TupleContains<T, Tuple>::value;
template <typename... Ts>
constexpr bool AreAllConstantV = (std::remove_reference_t<Ts>::is_constant && ...);
} }
template <typename T, typename ChildT> template <typename T, typename ChildT>
@@ -167,16 +170,23 @@ namespace Learner::Autograd::UnivariateStatic
typename SFINAE = std::enable_if_t<Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>> typename SFINAE = std::enable_if_t<Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
[[nodiscard]] auto grad(const std::tuple<ArgsTs...>& args) const [[nodiscard]] auto grad(const std::tuple<ArgsTs...>& args) const
{ {
const ChildT* this_ = static_cast<const ChildT*>(this); if constexpr (ChildT::is_constant)
const auto call_id = std::get<Detail::CallId>(args);
if (!grad_cache.has_value() || grad_cache_call_id != call_id)
{ {
grad_cache_call_id = call_id; return T(0.0);
grad_cache = this_->calculate_grad(args);
} }
else
{
const ChildT* this_ = static_cast<const ChildT*>(this);
return *grad_cache; const auto call_id = std::get<Detail::CallId>(args);
if (!grad_cache.has_value() || grad_cache_call_id != call_id)
{
grad_cache_call_id = call_id;
grad_cache = this_->calculate_grad(args);
}
return *grad_cache;
}
} }
template <typename... ArgsTs, template <typename... ArgsTs,
@@ -199,6 +209,8 @@ namespace Learner::Autograd::UnivariateStatic
struct VariableParameter : Evaluable<T, VariableParameter<T, I>> struct VariableParameter : Evaluable<T, VariableParameter<T, I>>
{ {
using ValueType = T; using ValueType = T;
static constexpr bool is_constant = false;
constexpr VariableParameter() constexpr VariableParameter()
{ {
@@ -222,6 +234,8 @@ namespace Learner::Autograd::UnivariateStatic
{ {
using ValueType = T; using ValueType = T;
static constexpr bool is_constant = true;
constexpr ConstantParameter() constexpr ConstantParameter()
{ {
} }
@@ -244,6 +258,8 @@ namespace Learner::Autograd::UnivariateStatic
{ {
using ValueType = T; using ValueType = T;
static constexpr bool is_constant = true;
constexpr Constant(T x) : constexpr Constant(T x) :
m_x(std::move(x)) m_x(std::move(x))
{ {
@@ -270,6 +286,8 @@ namespace Learner::Autograd::UnivariateStatic
{ {
using ValueType = T; using ValueType = T;
static constexpr bool is_constant = Detail::AreAllConstantV<LhsT, RhsT>;
constexpr Sum(LhsT&& lhs, RhsT&& rhs) : constexpr Sum(LhsT&& lhs, RhsT&& rhs) :
m_lhs(std::forward<LhsT>(lhs)), m_lhs(std::forward<LhsT>(lhs)),
m_rhs(std::forward<RhsT>(rhs)) m_rhs(std::forward<RhsT>(rhs))
@@ -316,6 +334,8 @@ namespace Learner::Autograd::UnivariateStatic
{ {
using ValueType = T; using ValueType = T;
static constexpr bool is_constant = Detail::AreAllConstantV<LhsT, RhsT>;
constexpr Difference(LhsT&& lhs, RhsT&& rhs) : constexpr Difference(LhsT&& lhs, RhsT&& rhs) :
m_lhs(std::forward<LhsT>(lhs)), m_lhs(std::forward<LhsT>(lhs)),
m_rhs(std::forward<RhsT>(rhs)) m_rhs(std::forward<RhsT>(rhs))
@@ -362,6 +382,8 @@ namespace Learner::Autograd::UnivariateStatic
{ {
using ValueType = T; using ValueType = T;
static constexpr bool is_constant = Detail::AreAllConstantV<LhsT, RhsT>;
constexpr Product(LhsT&& lhs, RhsT&& rhs) : constexpr Product(LhsT&& lhs, RhsT&& rhs) :
m_lhs(std::forward<LhsT>(lhs)), m_lhs(std::forward<LhsT>(lhs)),
m_rhs(std::forward<RhsT>(rhs)) m_rhs(std::forward<RhsT>(rhs))
@@ -408,6 +430,8 @@ namespace Learner::Autograd::UnivariateStatic
{ {
using ValueType = T; using ValueType = T;
static constexpr bool is_constant = Detail::AreAllConstantV<ArgT>;
constexpr explicit Negation(ArgT&& x) : constexpr explicit Negation(ArgT&& x) :
m_x(std::forward<ArgT>(x)) m_x(std::forward<ArgT>(x))
{ {
@@ -440,6 +464,8 @@ namespace Learner::Autograd::UnivariateStatic
{ {
using ValueType = T; using ValueType = T;
static constexpr bool is_constant = Detail::AreAllConstantV<ArgT>;
constexpr explicit Sigmoid(ArgT&& x) : constexpr explicit Sigmoid(ArgT&& x) :
m_x(std::forward<ArgT>(x)) m_x(std::forward<ArgT>(x))
{ {
@@ -482,6 +508,8 @@ namespace Learner::Autograd::UnivariateStatic
{ {
using ValueType = T; using ValueType = T;
static constexpr bool is_constant = Detail::AreAllConstantV<ArgT>;
constexpr explicit Pow(ArgT&& x, Id<T> exponent) : constexpr explicit Pow(ArgT&& x, Id<T> exponent) :
m_x(std::forward<ArgT>(x)), m_x(std::forward<ArgT>(x)),
m_exponent(std::move(exponent)) m_exponent(std::move(exponent))
@@ -516,6 +544,8 @@ namespace Learner::Autograd::UnivariateStatic
{ {
using ValueType = T; using ValueType = T;
static constexpr bool is_constant = Detail::AreAllConstantV<ArgT>;
constexpr explicit Log(ArgT&& x) : constexpr explicit Log(ArgT&& x) :
m_x(std::forward<ArgT>(x)) m_x(std::forward<ArgT>(x))
{ {