Identify a single evalation chain by ID in autograd to prevent cache reuse for subsequent evaluations of the same expression tree.

This commit is contained in:
Tomasz Sobczyk
2020-11-30 14:01:31 +01:00
committed by nodchip
parent cb812c742c
commit 8adf00ae6e
+82 -5
View File
@@ -8,6 +8,7 @@
#include <tuple> #include <tuple>
#include <optional> #include <optional>
#include <algorithm> #include <algorithm>
#include <cstdint>
namespace Learner namespace Learner
{ {
@@ -76,46 +77,122 @@ namespace Learner::Autograd::UnivariateStatic
const std::remove_reference_t<T>& const std::remove_reference_t<T>&
>; >;
namespace Detail
{
using CallIdType = std::uint32_t;
struct CallId
{
CallIdType call_id{};
constexpr CallId() :
call_id(0)
{
}
constexpr CallId(CallIdType id) :
call_id(id)
{
}
[[nodiscard]] bool operator==(CallId rhs) const noexcept
{
return call_id == rhs.call_id;
}
[[nodiscard]] bool operator!=(CallId rhs) const noexcept
{
return call_id != rhs.call_id;
}
};
[[nodiscard]] inline CallId next_call_id()
{
static thread_local CallIdType s_call_id = 0;
return CallId{ s_call_id++ };
}
template <typename T, typename Tuple>
struct TupleContains;
template <typename T, typename... Us>
struct TupleContains<T, std::tuple<Us...>> : std::disjunction<std::is_same<T, Us>...> {};
template <typename T, typename Tuple>
constexpr bool TupleContainsV = TupleContains<T, Tuple>::value;
}
template <typename T, typename ChildT> template <typename T, typename ChildT>
struct Evaluable struct Evaluable
{ {
constexpr Evaluable() = default; constexpr Evaluable() = default;
// We append a unique call id so that we can invalidate the cache when
// the next computation starts. A single evaluation should see
// the same call_id at every node.
template <typename... ArgsTs> template <typename... ArgsTs>
[[nodiscard]] auto eval(const std::tuple<ArgsTs...>& args) const [[nodiscard]] auto eval(const std::tuple<ArgsTs...>& args) const
{ {
return ValueWithGrad<T>{ value(args), grad(args) }; const auto call_id = Detail::next_call_id();
const auto new_args = std::tuple_cat(args, std::tuple(call_id));
return ValueWithGrad<T>{ value(new_args), grad(new_args) };
} }
template <typename... ArgsTs> template <typename... ArgsTs,
typename SFINAE = std::enable_if_t<Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
[[nodiscard]] auto value(const std::tuple<ArgsTs...>& args) const [[nodiscard]] auto value(const std::tuple<ArgsTs...>& args) const
{ {
const ChildT* this_ = static_cast<const ChildT*>(this); const ChildT* this_ = static_cast<const ChildT*>(this);
if (!value_cache.has_value()) const auto call_id = std::get<Detail::CallId>(args);
if (!value_cache.has_value() || value_cache_call_id != call_id)
{ {
value_cache_call_id = call_id;
value_cache = this_->calculate_value(args); value_cache = this_->calculate_value(args);
} }
return *value_cache; return *value_cache;
} }
template <typename... ArgsTs> template <typename... ArgsTs,
typename SFINAE = std::enable_if_t<!Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
[[nodiscard]] auto value(const std::tuple<ArgsTs...>& args, ...) const
{
const auto call_id = Detail::next_call_id();
const auto new_args = std::tuple_cat(args, std::tuple(call_id));
return value(new_args);
}
template <typename... ArgsTs,
typename SFINAE = std::enable_if_t<Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
[[nodiscard]] auto grad(const std::tuple<ArgsTs...>& args) const [[nodiscard]] auto grad(const std::tuple<ArgsTs...>& args) const
{ {
const ChildT* this_ = static_cast<const ChildT*>(this); const ChildT* this_ = static_cast<const ChildT*>(this);
if (!grad_cache.has_value()) const auto call_id = std::get<Detail::CallId>(args);
if (!grad_cache.has_value() || grad_cache_call_id != call_id)
{ {
grad_cache_call_id = call_id;
grad_cache = this_->calculate_grad(args); grad_cache = this_->calculate_grad(args);
} }
return *grad_cache; return *grad_cache;
} }
template <typename... ArgsTs,
typename SFINAE = std::enable_if_t<!Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
[[nodiscard]] auto grad(const std::tuple<ArgsTs...>& args, ...) const
{
const auto call_id = Detail::next_call_id();
const auto new_args = std::tuple_cat(args, std::tuple(call_id));
return grad(new_args);
}
private: private:
mutable std::optional<T> value_cache; mutable std::optional<T> value_cache;
mutable std::optional<T> grad_cache; mutable std::optional<T> grad_cache;
mutable Detail::CallId value_cache_call_id{};
mutable Detail::CallId grad_cache_call_id{};
}; };
template <typename T, int I> template <typename T, int I>