When forming an autograd expression only copy parts that are rvalue references, store references to lvalues.

This commit is contained in:
Tomasz Sobczyk
2020-11-29 19:06:31 +01:00
committed by nodchip
parent a5c20bee5b
commit aec6017195
+80 -73
View File
@@ -69,6 +69,13 @@ namespace Learner::Autograd::UnivariateStatic
template <typename T> template <typename T>
using Id = typename Identity<T>::type; using Id = typename Identity<T>::type;
template <typename T>
using StoreValueOrRef = std::conditional_t<
std::is_rvalue_reference_v<T>,
std::remove_reference_t<T>,
const std::remove_reference_t<T>&
>;
template <typename T, typename ChildT> template <typename T, typename ChildT>
struct Evaluable struct Evaluable
{ {
@@ -179,14 +186,14 @@ namespace Learner::Autograd::UnivariateStatic
T m_x; T m_x;
}; };
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType> template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
struct Sum : Evaluable<T, Sum<LhsT, RhsT, T>> struct Sum : Evaluable<T, Sum<LhsT, RhsT, T>>
{ {
using ValueType = T; using ValueType = T;
Sum(LhsT lhs, RhsT rhs) : Sum(LhsT&& lhs, RhsT&& rhs) :
m_lhs(std::move(lhs)), m_lhs(std::forward<LhsT>(lhs)),
m_rhs(std::move(rhs)) m_rhs(std::forward<RhsT>(rhs))
{ {
} }
@@ -203,36 +210,36 @@ namespace Learner::Autograd::UnivariateStatic
} }
private: private:
LhsT m_lhs; StoreValueOrRef<LhsT> m_lhs;
RhsT m_rhs; StoreValueOrRef<RhsT> m_rhs;
}; };
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType> template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
auto operator+(LhsT lhs, RhsT rhs) auto operator+(LhsT&& lhs, RhsT&& rhs)
{ {
return Sum(std::move(lhs), std::move(rhs)); return Sum<LhsT&&, RhsT&&>(std::forward<LhsT>(lhs), std::forward<RhsT>(rhs));
} }
template <typename LhsT, typename T = typename LhsT::ValueType> template <typename LhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
auto operator+(LhsT lhs, Id<T> rhs) auto operator+(LhsT&& lhs, Id<T> rhs)
{ {
return Sum(std::move(lhs), Constant(rhs)); return Sum<LhsT&&, Constant<T>&&>(std::forward<LhsT>(lhs), Constant(rhs));
} }
template <typename RhsT, typename T = typename RhsT::ValueType> template <typename RhsT, typename T = typename std::remove_reference_t<RhsT>::ValueType>
auto operator+(Id<T> lhs, RhsT rhs) auto operator+(Id<T> lhs, RhsT&& rhs)
{ {
return Sum(Constant(lhs), std::move(rhs)); return Sum<Constant<T>&&, RhsT&&>(Constant(lhs), std::forward<RhsT>(rhs));
} }
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType> template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
struct Difference : Evaluable<T, Difference<LhsT, RhsT, T>> struct Difference : Evaluable<T, Difference<LhsT, RhsT, T>>
{ {
using ValueType = T; using ValueType = T;
Difference(LhsT lhs, RhsT rhs) : Difference(LhsT&& lhs, RhsT&& rhs) :
m_lhs(std::move(lhs)), m_lhs(std::forward<LhsT>(lhs)),
m_rhs(std::move(rhs)) m_rhs(std::forward<RhsT>(rhs))
{ {
} }
@@ -249,36 +256,36 @@ namespace Learner::Autograd::UnivariateStatic
} }
private: private:
LhsT m_lhs; StoreValueOrRef<LhsT> m_lhs;
RhsT m_rhs; StoreValueOrRef<RhsT> m_rhs;
}; };
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType> template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
auto operator-(LhsT lhs, RhsT rhs) auto operator-(LhsT&& lhs, RhsT&& rhs)
{ {
return Difference(std::move(lhs), std::move(rhs)); return Difference<LhsT&&, RhsT&&>(std::forward<LhsT>(lhs), std::forward<RhsT>(rhs));
} }
template <typename LhsT, typename T = typename LhsT::ValueType> template <typename LhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
auto operator-(LhsT lhs, Id<T> rhs) auto operator-(LhsT&& lhs, Id<T> rhs)
{ {
return Difference(std::move(lhs), Constant(rhs)); return Difference<LhsT&&, Constant<T>&&>(std::forward<LhsT>(lhs), Constant(rhs));
} }
template <typename RhsT, typename T = typename RhsT::ValueType> template <typename RhsT, typename T = typename std::remove_reference_t<RhsT>::ValueType>
auto operator-(Id<T> lhs, RhsT rhs) auto operator-(Id<T> lhs, RhsT&& rhs)
{ {
return Difference(Constant(lhs), std::move(rhs)); return Difference<Constant<T>&&, RhsT&&>(Constant(lhs), std::forward<RhsT>(rhs));
} }
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType> template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
struct Product : Evaluable<T, Product<LhsT, RhsT, T>> struct Product : Evaluable<T, Product<LhsT, RhsT, T>>
{ {
using ValueType = T; using ValueType = T;
Product(LhsT lhs, RhsT rhs) : Product(LhsT&& lhs, RhsT&& rhs) :
m_lhs(std::move(lhs)), m_lhs(std::forward<LhsT>(lhs)),
m_rhs(std::move(rhs)) m_rhs(std::forward<RhsT>(rhs))
{ {
} }
@@ -295,35 +302,35 @@ namespace Learner::Autograd::UnivariateStatic
} }
private: private:
LhsT m_lhs; StoreValueOrRef<LhsT> m_lhs;
RhsT m_rhs; StoreValueOrRef<RhsT> m_rhs;
}; };
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType> template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
auto operator*(LhsT lhs, RhsT rhs) auto operator*(LhsT&& lhs, RhsT&& rhs)
{ {
return Product(std::move(lhs), std::move(rhs)); return Product<LhsT&&, RhsT&&>(std::forward<LhsT>(lhs), std::forward<RhsT>(rhs));
} }
template <typename LhsT, typename T = typename LhsT::ValueType> template <typename LhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
auto operator*(LhsT lhs, Id<T> rhs) auto operator*(LhsT&& lhs, Id<T> rhs)
{ {
return Product(std::move(lhs), Constant(rhs)); return Product<LhsT&&, Constant<T>&&>(std::forward<LhsT>(lhs), Constant(rhs));
} }
template <typename RhsT, typename T = typename RhsT::ValueType> template <typename RhsT, typename T = typename std::remove_reference_t<RhsT>::ValueType>
auto operator*(Id<T> lhs, RhsT rhs) auto operator*(Id<T> lhs, RhsT&& rhs)
{ {
return Product(Constant(lhs), std::move(rhs)); return Product<Constant<T>&&, RhsT&&>(Constant(lhs), std::forward<RhsT>(rhs));
} }
template <typename ArgT, typename T = typename ArgT::ValueType> template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
struct Negation : Evaluable<T, Negation<ArgT, T>> struct Negation : Evaluable<T, Negation<ArgT, T>>
{ {
using ValueType = T; using ValueType = T;
explicit Negation(ArgT x) : explicit Negation(ArgT&& x) :
m_x(std::move(x)) m_x(std::forward<ArgT>(x))
{ {
} }
@@ -340,22 +347,22 @@ namespace Learner::Autograd::UnivariateStatic
} }
private: private:
ArgT m_x; StoreValueOrRef<ArgT> m_x;
}; };
template <typename ArgT, typename T = typename ArgT::ValueType> template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
auto operator-(ArgT x) auto operator-(ArgT&& x)
{ {
return Negation(std::move(x)); return Negation<ArgT&&>(std::forward<ArgT>(x));
} }
template <typename ArgT, typename T = typename ArgT::ValueType> template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
struct Sigmoid : Evaluable<T, Sigmoid<ArgT, T>> struct Sigmoid : Evaluable<T, Sigmoid<ArgT, T>>
{ {
using ValueType = T; using ValueType = T;
explicit Sigmoid(ArgT x) : explicit Sigmoid(ArgT&& x) :
m_x(std::move(x)) m_x(std::forward<ArgT>(x))
{ {
} }
@@ -372,7 +379,7 @@ namespace Learner::Autograd::UnivariateStatic
} }
private: private:
ArgT m_x; StoreValueOrRef<ArgT> m_x;
T value_(T x) const T value_(T x) const
{ {
@@ -385,19 +392,19 @@ namespace Learner::Autograd::UnivariateStatic
} }
}; };
template <typename ArgT, typename T = typename ArgT::ValueType> template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
auto sigmoid(ArgT x) auto sigmoid(ArgT&& x)
{ {
return Sigmoid(std::move(x)); return Sigmoid<ArgT&&>(std::forward<ArgT>(x));
} }
template <typename ArgT, typename T = typename ArgT::ValueType> template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
struct Pow : Evaluable<T, Pow<ArgT, T>> struct Pow : Evaluable<T, Pow<ArgT, T>>
{ {
using ValueType = T; using ValueType = T;
explicit Pow(ArgT x, Id<T> exponent) : explicit Pow(ArgT&& x, Id<T> exponent) :
m_x(std::move(x)), m_x(std::forward<ArgT>(x)),
m_exponent(std::move(exponent)) m_exponent(std::move(exponent))
{ {
} }
@@ -415,23 +422,23 @@ namespace Learner::Autograd::UnivariateStatic
} }
private: private:
ArgT m_x; StoreValueOrRef<ArgT> m_x;
T m_exponent; T m_exponent;
}; };
template <typename ArgT, typename T = typename ArgT::ValueType> template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
auto pow(ArgT x, Id<T> exp) auto pow(ArgT&& x, Id<T> exp)
{ {
return Pow(std::move(x), std::move(exp)); return Pow<ArgT&&>(std::forward<ArgT>(x), std::move(exp));
} }
template <typename ArgT, typename T = typename ArgT::ValueType> template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
struct Log : Evaluable<T, Log<ArgT, T>> struct Log : Evaluable<T, Log<ArgT, T>>
{ {
using ValueType = T; using ValueType = T;
explicit Log(ArgT x) : explicit Log(ArgT&& x) :
m_x(std::move(x)) m_x(std::forward<ArgT>(x))
{ {
} }
@@ -448,7 +455,7 @@ namespace Learner::Autograd::UnivariateStatic
} }
private: private:
ArgT m_x; StoreValueOrRef<ArgT> m_x;
T value_(T x) const T value_(T x) const
{ {
@@ -461,10 +468,10 @@ namespace Learner::Autograd::UnivariateStatic
} }
}; };
template <typename ArgT, typename T = typename ArgT::ValueType> template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
auto log(ArgT x) auto log(ArgT&& x)
{ {
return Log(std::move(x)); return Log<ArgT&&>(std::forward<ArgT>(x));
} }
} }