mirror of
https://github.com/opelly27/Stockfish.git
synced 2026-05-20 12:07:43 +00:00
When forming an autograd expression only copy parts that are rvalue references, store references to lvalues.
This commit is contained in:
+80
-73
@@ -69,6 +69,13 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
using Id = typename Identity<T>::type;
|
using Id = typename Identity<T>::type;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using StoreValueOrRef = std::conditional_t<
|
||||||
|
std::is_rvalue_reference_v<T>,
|
||||||
|
std::remove_reference_t<T>,
|
||||||
|
const std::remove_reference_t<T>&
|
||||||
|
>;
|
||||||
|
|
||||||
template <typename T, typename ChildT>
|
template <typename T, typename ChildT>
|
||||||
struct Evaluable
|
struct Evaluable
|
||||||
{
|
{
|
||||||
@@ -179,14 +186,14 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
T m_x;
|
T m_x;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
|
||||||
struct Sum : Evaluable<T, Sum<LhsT, RhsT, T>>
|
struct Sum : Evaluable<T, Sum<LhsT, RhsT, T>>
|
||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
Sum(LhsT lhs, RhsT rhs) :
|
Sum(LhsT&& lhs, RhsT&& rhs) :
|
||||||
m_lhs(std::move(lhs)),
|
m_lhs(std::forward<LhsT>(lhs)),
|
||||||
m_rhs(std::move(rhs))
|
m_rhs(std::forward<RhsT>(rhs))
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -203,36 +210,36 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
LhsT m_lhs;
|
StoreValueOrRef<LhsT> m_lhs;
|
||||||
RhsT m_rhs;
|
StoreValueOrRef<RhsT> m_rhs;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
|
||||||
auto operator+(LhsT lhs, RhsT rhs)
|
auto operator+(LhsT&& lhs, RhsT&& rhs)
|
||||||
{
|
{
|
||||||
return Sum(std::move(lhs), std::move(rhs));
|
return Sum<LhsT&&, RhsT&&>(std::forward<LhsT>(lhs), std::forward<RhsT>(rhs));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename LhsT, typename T = typename LhsT::ValueType>
|
template <typename LhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
|
||||||
auto operator+(LhsT lhs, Id<T> rhs)
|
auto operator+(LhsT&& lhs, Id<T> rhs)
|
||||||
{
|
{
|
||||||
return Sum(std::move(lhs), Constant(rhs));
|
return Sum<LhsT&&, Constant<T>&&>(std::forward<LhsT>(lhs), Constant(rhs));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename RhsT, typename T = typename RhsT::ValueType>
|
template <typename RhsT, typename T = typename std::remove_reference_t<RhsT>::ValueType>
|
||||||
auto operator+(Id<T> lhs, RhsT rhs)
|
auto operator+(Id<T> lhs, RhsT&& rhs)
|
||||||
{
|
{
|
||||||
return Sum(Constant(lhs), std::move(rhs));
|
return Sum<Constant<T>&&, RhsT&&>(Constant(lhs), std::forward<RhsT>(rhs));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
|
||||||
struct Difference : Evaluable<T, Difference<LhsT, RhsT, T>>
|
struct Difference : Evaluable<T, Difference<LhsT, RhsT, T>>
|
||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
Difference(LhsT lhs, RhsT rhs) :
|
Difference(LhsT&& lhs, RhsT&& rhs) :
|
||||||
m_lhs(std::move(lhs)),
|
m_lhs(std::forward<LhsT>(lhs)),
|
||||||
m_rhs(std::move(rhs))
|
m_rhs(std::forward<RhsT>(rhs))
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -249,36 +256,36 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
LhsT m_lhs;
|
StoreValueOrRef<LhsT> m_lhs;
|
||||||
RhsT m_rhs;
|
StoreValueOrRef<RhsT> m_rhs;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
|
||||||
auto operator-(LhsT lhs, RhsT rhs)
|
auto operator-(LhsT&& lhs, RhsT&& rhs)
|
||||||
{
|
{
|
||||||
return Difference(std::move(lhs), std::move(rhs));
|
return Difference<LhsT&&, RhsT&&>(std::forward<LhsT>(lhs), std::forward<RhsT>(rhs));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename LhsT, typename T = typename LhsT::ValueType>
|
template <typename LhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
|
||||||
auto operator-(LhsT lhs, Id<T> rhs)
|
auto operator-(LhsT&& lhs, Id<T> rhs)
|
||||||
{
|
{
|
||||||
return Difference(std::move(lhs), Constant(rhs));
|
return Difference<LhsT&&, Constant<T>&&>(std::forward<LhsT>(lhs), Constant(rhs));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename RhsT, typename T = typename RhsT::ValueType>
|
template <typename RhsT, typename T = typename std::remove_reference_t<RhsT>::ValueType>
|
||||||
auto operator-(Id<T> lhs, RhsT rhs)
|
auto operator-(Id<T> lhs, RhsT&& rhs)
|
||||||
{
|
{
|
||||||
return Difference(Constant(lhs), std::move(rhs));
|
return Difference<Constant<T>&&, RhsT&&>(Constant(lhs), std::forward<RhsT>(rhs));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
|
||||||
struct Product : Evaluable<T, Product<LhsT, RhsT, T>>
|
struct Product : Evaluable<T, Product<LhsT, RhsT, T>>
|
||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
Product(LhsT lhs, RhsT rhs) :
|
Product(LhsT&& lhs, RhsT&& rhs) :
|
||||||
m_lhs(std::move(lhs)),
|
m_lhs(std::forward<LhsT>(lhs)),
|
||||||
m_rhs(std::move(rhs))
|
m_rhs(std::forward<RhsT>(rhs))
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -295,35 +302,35 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
LhsT m_lhs;
|
StoreValueOrRef<LhsT> m_lhs;
|
||||||
RhsT m_rhs;
|
StoreValueOrRef<RhsT> m_rhs;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
|
||||||
auto operator*(LhsT lhs, RhsT rhs)
|
auto operator*(LhsT&& lhs, RhsT&& rhs)
|
||||||
{
|
{
|
||||||
return Product(std::move(lhs), std::move(rhs));
|
return Product<LhsT&&, RhsT&&>(std::forward<LhsT>(lhs), std::forward<RhsT>(rhs));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename LhsT, typename T = typename LhsT::ValueType>
|
template <typename LhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
|
||||||
auto operator*(LhsT lhs, Id<T> rhs)
|
auto operator*(LhsT&& lhs, Id<T> rhs)
|
||||||
{
|
{
|
||||||
return Product(std::move(lhs), Constant(rhs));
|
return Product<LhsT&&, Constant<T>&&>(std::forward<LhsT>(lhs), Constant(rhs));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename RhsT, typename T = typename RhsT::ValueType>
|
template <typename RhsT, typename T = typename std::remove_reference_t<RhsT>::ValueType>
|
||||||
auto operator*(Id<T> lhs, RhsT rhs)
|
auto operator*(Id<T> lhs, RhsT&& rhs)
|
||||||
{
|
{
|
||||||
return Product(Constant(lhs), std::move(rhs));
|
return Product<Constant<T>&&, RhsT&&>(Constant(lhs), std::forward<RhsT>(rhs));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
|
||||||
struct Negation : Evaluable<T, Negation<ArgT, T>>
|
struct Negation : Evaluable<T, Negation<ArgT, T>>
|
||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
explicit Negation(ArgT x) :
|
explicit Negation(ArgT&& x) :
|
||||||
m_x(std::move(x))
|
m_x(std::forward<ArgT>(x))
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -340,22 +347,22 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ArgT m_x;
|
StoreValueOrRef<ArgT> m_x;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
|
||||||
auto operator-(ArgT x)
|
auto operator-(ArgT&& x)
|
||||||
{
|
{
|
||||||
return Negation(std::move(x));
|
return Negation<ArgT&&>(std::forward<ArgT>(x));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
|
||||||
struct Sigmoid : Evaluable<T, Sigmoid<ArgT, T>>
|
struct Sigmoid : Evaluable<T, Sigmoid<ArgT, T>>
|
||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
explicit Sigmoid(ArgT x) :
|
explicit Sigmoid(ArgT&& x) :
|
||||||
m_x(std::move(x))
|
m_x(std::forward<ArgT>(x))
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -372,7 +379,7 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ArgT m_x;
|
StoreValueOrRef<ArgT> m_x;
|
||||||
|
|
||||||
T value_(T x) const
|
T value_(T x) const
|
||||||
{
|
{
|
||||||
@@ -385,19 +392,19 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
|
||||||
auto sigmoid(ArgT x)
|
auto sigmoid(ArgT&& x)
|
||||||
{
|
{
|
||||||
return Sigmoid(std::move(x));
|
return Sigmoid<ArgT&&>(std::forward<ArgT>(x));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
|
||||||
struct Pow : Evaluable<T, Pow<ArgT, T>>
|
struct Pow : Evaluable<T, Pow<ArgT, T>>
|
||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
explicit Pow(ArgT x, Id<T> exponent) :
|
explicit Pow(ArgT&& x, Id<T> exponent) :
|
||||||
m_x(std::move(x)),
|
m_x(std::forward<ArgT>(x)),
|
||||||
m_exponent(std::move(exponent))
|
m_exponent(std::move(exponent))
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
@@ -415,23 +422,23 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ArgT m_x;
|
StoreValueOrRef<ArgT> m_x;
|
||||||
T m_exponent;
|
T m_exponent;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
|
||||||
auto pow(ArgT x, Id<T> exp)
|
auto pow(ArgT&& x, Id<T> exp)
|
||||||
{
|
{
|
||||||
return Pow(std::move(x), std::move(exp));
|
return Pow<ArgT&&>(std::forward<ArgT>(x), std::move(exp));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
|
||||||
struct Log : Evaluable<T, Log<ArgT, T>>
|
struct Log : Evaluable<T, Log<ArgT, T>>
|
||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
explicit Log(ArgT x) :
|
explicit Log(ArgT&& x) :
|
||||||
m_x(std::move(x))
|
m_x(std::forward<ArgT>(x))
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -448,7 +455,7 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ArgT m_x;
|
StoreValueOrRef<ArgT> m_x;
|
||||||
|
|
||||||
T value_(T x) const
|
T value_(T x) const
|
||||||
{
|
{
|
||||||
@@ -461,10 +468,10 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
|
||||||
auto log(ArgT x)
|
auto log(ArgT&& x)
|
||||||
{
|
{
|
||||||
return Log(std::move(x));
|
return Log<ArgT&&>(std::forward<ArgT>(x));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user