mirror of
https://github.com/opelly27/Stockfish.git
synced 2026-05-20 15:37:47 +00:00
Identify a single evalation chain by ID in autograd to prevent cache reuse for subsequent evaluations of the same expression tree.
This commit is contained in:
+82
-5
@@ -8,6 +8,7 @@
|
|||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
namespace Learner
|
namespace Learner
|
||||||
{
|
{
|
||||||
@@ -76,46 +77,122 @@ namespace Learner::Autograd::UnivariateStatic
|
|||||||
const std::remove_reference_t<T>&
|
const std::remove_reference_t<T>&
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
namespace Detail
|
||||||
|
{
|
||||||
|
using CallIdType = std::uint32_t;
|
||||||
|
|
||||||
|
struct CallId
|
||||||
|
{
|
||||||
|
CallIdType call_id{};
|
||||||
|
|
||||||
|
constexpr CallId() :
|
||||||
|
call_id(0)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr CallId(CallIdType id) :
|
||||||
|
call_id(id)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] bool operator==(CallId rhs) const noexcept
|
||||||
|
{
|
||||||
|
return call_id == rhs.call_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] bool operator!=(CallId rhs) const noexcept
|
||||||
|
{
|
||||||
|
return call_id != rhs.call_id;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
[[nodiscard]] inline CallId next_call_id()
|
||||||
|
{
|
||||||
|
static thread_local CallIdType s_call_id = 0;
|
||||||
|
return CallId{ s_call_id++ };
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename Tuple>
|
||||||
|
struct TupleContains;
|
||||||
|
|
||||||
|
template <typename T, typename... Us>
|
||||||
|
struct TupleContains<T, std::tuple<Us...>> : std::disjunction<std::is_same<T, Us>...> {};
|
||||||
|
|
||||||
|
template <typename T, typename Tuple>
|
||||||
|
constexpr bool TupleContainsV = TupleContains<T, Tuple>::value;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T, typename ChildT>
|
template <typename T, typename ChildT>
|
||||||
struct Evaluable
|
struct Evaluable
|
||||||
{
|
{
|
||||||
constexpr Evaluable() = default;
|
constexpr Evaluable() = default;
|
||||||
|
|
||||||
|
// We append a unique call id so that we can invalidate the cache when
|
||||||
|
// the next computation starts. A single evaluation should see
|
||||||
|
// the same call_id at every node.
|
||||||
template <typename... ArgsTs>
|
template <typename... ArgsTs>
|
||||||
[[nodiscard]] auto eval(const std::tuple<ArgsTs...>& args) const
|
[[nodiscard]] auto eval(const std::tuple<ArgsTs...>& args) const
|
||||||
{
|
{
|
||||||
return ValueWithGrad<T>{ value(args), grad(args) };
|
const auto call_id = Detail::next_call_id();
|
||||||
|
const auto new_args = std::tuple_cat(args, std::tuple(call_id));
|
||||||
|
return ValueWithGrad<T>{ value(new_args), grad(new_args) };
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename... ArgsTs>
|
template <typename... ArgsTs,
|
||||||
|
typename SFINAE = std::enable_if_t<Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
|
||||||
[[nodiscard]] auto value(const std::tuple<ArgsTs...>& args) const
|
[[nodiscard]] auto value(const std::tuple<ArgsTs...>& args) const
|
||||||
{
|
{
|
||||||
const ChildT* this_ = static_cast<const ChildT*>(this);
|
const ChildT* this_ = static_cast<const ChildT*>(this);
|
||||||
|
|
||||||
if (!value_cache.has_value())
|
const auto call_id = std::get<Detail::CallId>(args);
|
||||||
|
if (!value_cache.has_value() || value_cache_call_id != call_id)
|
||||||
{
|
{
|
||||||
|
value_cache_call_id = call_id;
|
||||||
value_cache = this_->calculate_value(args);
|
value_cache = this_->calculate_value(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
return *value_cache;
|
return *value_cache;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename... ArgsTs>
|
template <typename... ArgsTs,
|
||||||
|
typename SFINAE = std::enable_if_t<!Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
|
||||||
|
[[nodiscard]] auto value(const std::tuple<ArgsTs...>& args, ...) const
|
||||||
|
{
|
||||||
|
const auto call_id = Detail::next_call_id();
|
||||||
|
const auto new_args = std::tuple_cat(args, std::tuple(call_id));
|
||||||
|
return value(new_args);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... ArgsTs,
|
||||||
|
typename SFINAE = std::enable_if_t<Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
|
||||||
[[nodiscard]] auto grad(const std::tuple<ArgsTs...>& args) const
|
[[nodiscard]] auto grad(const std::tuple<ArgsTs...>& args) const
|
||||||
{
|
{
|
||||||
const ChildT* this_ = static_cast<const ChildT*>(this);
|
const ChildT* this_ = static_cast<const ChildT*>(this);
|
||||||
|
|
||||||
if (!grad_cache.has_value())
|
const auto call_id = std::get<Detail::CallId>(args);
|
||||||
|
if (!grad_cache.has_value() || grad_cache_call_id != call_id)
|
||||||
{
|
{
|
||||||
|
grad_cache_call_id = call_id;
|
||||||
grad_cache = this_->calculate_grad(args);
|
grad_cache = this_->calculate_grad(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
return *grad_cache;
|
return *grad_cache;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename... ArgsTs,
|
||||||
|
typename SFINAE = std::enable_if_t<!Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
|
||||||
|
[[nodiscard]] auto grad(const std::tuple<ArgsTs...>& args, ...) const
|
||||||
|
{
|
||||||
|
const auto call_id = Detail::next_call_id();
|
||||||
|
const auto new_args = std::tuple_cat(args, std::tuple(call_id));
|
||||||
|
return grad(new_args);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
mutable std::optional<T> value_cache;
|
mutable std::optional<T> value_cache;
|
||||||
mutable std::optional<T> grad_cache;
|
mutable std::optional<T> grad_cache;
|
||||||
|
mutable Detail::CallId value_cache_call_id{};
|
||||||
|
mutable Detail::CallId grad_cache_call_id{};
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, int I>
|
template <typename T, int I>
|
||||||
|
|||||||
Reference in New Issue
Block a user