mirror of
https://github.com/opelly27/Stockfish.git
synced 2026-05-20 13:17:44 +00:00
More utility in autograd.
This commit is contained in:
+59
-9
@@ -7,6 +7,44 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
|
||||||
|
namespace Learner
|
||||||
|
{
|
||||||
|
template <typename T>
|
||||||
|
struct ValueWithGrad
|
||||||
|
{
|
||||||
|
T value;
|
||||||
|
T grad;
|
||||||
|
|
||||||
|
ValueWithGrad& operator+=(const ValueWithGrad<T>& rhs)
|
||||||
|
{
|
||||||
|
value += rhs.value;
|
||||||
|
grad += rhs.grad;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
ValueWithGrad& operator-=(const ValueWithGrad<T>& rhs)
|
||||||
|
{
|
||||||
|
value -= rhs.value;
|
||||||
|
grad -= rhs.grad;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
ValueWithGrad& operator*=(T rhs)
|
||||||
|
{
|
||||||
|
value *= rhs;
|
||||||
|
grad *= rhs;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
ValueWithGrad& operator/=(T rhs)
|
||||||
|
{
|
||||||
|
value /= rhs;
|
||||||
|
grad /= rhs;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
namespace Learner::Autograd::UnivariateStatic
|
namespace Learner::Autograd::UnivariateStatic
|
||||||
{
|
{
|
||||||
|
|
||||||
@@ -19,8 +57,20 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
using Id = typename Identity<T>::type;
|
using Id = typename Identity<T>::type;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct Evaluable
|
||||||
|
{
|
||||||
|
template <typename... ArgsTs>
|
||||||
|
auto eval(const std::tuple<ArgsTs...>& args) const
|
||||||
|
{
|
||||||
|
using ValueType = typename T::ValueType;
|
||||||
|
const T* this_ = static_cast<const T*>(this);
|
||||||
|
return ValueWithGrad<ValueType>{ this_->value(args), this_->grad(args) };
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T, int I>
|
template <typename T, int I>
|
||||||
struct VariableParameter
|
struct VariableParameter : Evaluable<VariableParameter<T, I>>
|
||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
@@ -42,7 +92,7 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, int I>
|
template <typename T, int I>
|
||||||
struct ConstantParameter
|
struct ConstantParameter : Evaluable<ConstantParameter<T, I>>
|
||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
@@ -64,7 +114,7 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct Constant
|
struct Constant : Evaluable<Constant<T>>
|
||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
@@ -90,7 +140,7 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
||||||
struct Sum
|
struct Sum : Evaluable<Sum<LhsT, RhsT, T>>
|
||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
@@ -136,7 +186,7 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
||||||
struct Difference
|
struct Difference : Evaluable<Difference<LhsT, RhsT, T>>
|
||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
@@ -182,7 +232,7 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
||||||
struct Product
|
struct Product : Evaluable<Product<LhsT, RhsT, T>>
|
||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
@@ -228,7 +278,7 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||||
struct Sigmoid
|
struct Sigmoid : Evaluable<Sigmoid<ArgT, T>>
|
||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
@@ -270,7 +320,7 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||||
struct Pow
|
struct Pow : Evaluable<Pow<ArgT, T>>
|
||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
@@ -304,7 +354,7 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||||
struct Log
|
struct Log : Evaluable<Log<ArgT, T>>
|
||||||
{
|
{
|
||||||
using ValueType = T;
|
using ValueType = T;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user