C++泛型之柯里化、函数组合

原文:https://zhuanlan.zhihu.com/p/662698843

Reference:

C++ 实现 Currying 和 Partial application

从零开始的简单函数式C++(八)函数组合 - 知乎 (zhihu.com)

抛开柯里化和函数组合在非函数式的C++语言中能有多大用与抽象的额外开销不谈.. 本文只是单纯地解析一下如何利用C++泛型(模板)实现这两个高阶函数

一个最小实现其实不是很复杂,只要先了解一点变长模板、forwarding referencelambda就可以..

Compose

首先是函数组合(Compose

给定函数f1,f2,,fNf_1, f_2, \dots, f_N,我们希望给定输入a1,a2,aMa_1, a_2, \dots a_M,总有Compose(f1,f2,,fN)(a1aM)=fN(fN1(f2(f1(a1aM))))Compose(f_1, f_2, \dots, f_N)(a_1 \dots a_M) = f_N(f_{N-1}\dots (f_2(f_1(a_1 \dots a_M))))

也就是$Compose(f_1, f_2, \dots, f_N) \equiv f_N(f_{N-1}\dots (f_2(f_1))) $

由于$Compose(f_2, \dots, f_{N}) \equiv f_{N}(f_{N-1}\dots (f_2)) $

所以$Compose(f_1, f_2, \dots, f_N) \equiv a_1 \dots a_M \rightarrow Compose(f_2, \dots, f_{N})(f1(a_1 \dots a_M)) $

那可以看出,我们实现Compose的一种方式是写一个递归模板函数,但这样引入了递归的开销

另外考虑如果令a  f=f(a)a\ |\ f = f(a),那么

Compose(f1,f2,,fN)(a1aM)=f1(a1aM)  f2    fNCompose(f_1, f_2, \dots, f_N)(a_1 \dots a_M) = f_1(a_1\dots a_M) \ |\ f_2 \ |\ \dots\ |\ f_N

因为C++14中引入了折叠表达式,我们也可以通过operator重载 + 折叠表达式的方式实现这种非递归的链式调用与函数组合

0. 递归实现

先来看简单一点的递归版本实现

template <typename... Func> struct Compose {
private:
    using FuncChain = std::tuple<Func...>;
    FuncChain funcs;
public:
  template <typename... FuncRef>
  explicit Compose(FuncRef &&...);

  template <typename... ArgT> 
  decltype(auto) operator()(ArgT &&...);
}

我们把Compose实现为模板函数对象,模板参数为Func...,是需要组合的函数的类型列表

Compose内部,我们维护了一个tuple<Func...>,存储所有待组合的函数,因为调用组合函数时,它需要看到所有被组合的函数,所以我们才将Compose实现为能持有状态的函数对象

Compose的构造函数是template <typename... FuncRef> Compose(FuncRef &&...),使用Forwarding Refence(万能引用)区分FuncRef&&Func的左值引用还是右值引用,从而支持复制/移动语义

的这里不直接写成Compose(Func &&...)的原因是,编译器在生成Compose构造函数代码时,已经进入了Compose内,Func...类型已确定,Func&&不是万能引用,而是一个确定的类型

简单地说,万能引用的原理是引用折叠,当FuncRef = TFuncRef = T&&时,FunRef&& = T&&

FuncRef = T&时,FunRef&& = T&,了解了这些规则,就能理解这里为何还要套一层模板

最后是operator()这里返回类型声明为decltype(auto)

decltype(auto)对于编译器而言大概就是这个意思

decltype(auto) f() {
    return some_expr;
}
// ==
decltype((some_expr)) f() {
    return some_expr;
}

也就是保持返回的表达式的引用类型和cv限定符,避免Compose修改了原始函数们的返回值类型

template <typename... FuncRef>
explicit Compose(FuncRef &&...f) : funcs{std::forward<FuncRef>(f)...} {};

递归版本的话,operator()内部调用一个递归模板_invoke,令_invoke<Index, ArgT...>调用第Index个函数,然后把结果传给第Index + 1个函数_invoke<Index + 1, ArgT...>

template <std::size_t Index, typename... ArgT>
decltype(auto) _invoke(FuncChain &funcs, ArgT &&...args) {
  if constexpr (Index + 1 == sizeof...(Func))
    return std::get<Index>(funcs)(std::forward<ArgT>(args)...);
  else   
    return _invoke<Index + 1>(funcs)(std::get<Index>(funcs)(std::forward<ArgT>(args)...)); 
} 

template <typename... ArgT> decltype(auto) operator()(ArgT &&...args) {
  return _invoke<0, ArgT...>(funcs, std::forward<ArgT>(args)...);
}

如果当前是最后一个函数(Index + 1 == sizeof...(Func)),那么直接返回结果,否则递归调用

这样就完成了,但还有个小问题,我们现在的实现处理不了这种情况:中间某个函数返回了void,比如

string s{"nihao"};
auto get_ref_to_my_name = [&]() -> string & { return s; };
auto then_modify_it = [](string &sref) -> void {
  sref = "goodbye";
};
auto and_print = [&]() -> void {
  cout << "s == \"" << s << "\" now\n";
};
auto &&action1 = Compose(get_ref_to_my_name, then_modify_it, and_print);
action1();

因为C/C++void比较特殊,如果有void f(), void g(),那么f(g())是不符表达式语法的,组合fg应该用f(), g()

所以这里我们对void特判一下

decltype(auto) _invoke(FuncChain &funcs, ArgT &&...args) {
  // ...
  using CurrentResT = std::invoke_result_t<
      std::remove_reference_t<decltype(std::get<Index>(funcs))>, ArgT...>;
  if constexpr (!std::is_void_v<CurrentResT>)
    return _invoke<Index + 1, CurrentResT>(
        funcs, std::get<Index>(funcs)(std::forward<ArgT>(args)...));
  else {
    std::get<Index>(funcs)(std::forward<ArgT>(args)...);
    return _invoke<Index + 1>(funcs);
  }
} 

其中使用std::invoke_result_t让编译器告诉我们当前第Index个函数的返回值类型,如果是void,就避免把void传递到下一个函数

最后我们的Compose是一个模板类,所以写一个模板参数引导

template <typename... FuncRef>
Compose(FuncRef...) -> Compose<std::remove_reference_t<FuncRef>...>;

这里的引导的作用是帮助编译器Compose通过构造函数的形参类型推导出模板参数类型,我们构造函数输入是万能引用,存储的Func是值类型,所以用std::remove_reference_t擦除引用

auto &&action1 = Compose(get_ref_to_my_name, then_modify_it, and_print);

这样Compose就不用写模板参数了

那现在我们把它改成非递归版本的XD

1. 非递归实现

上文提到过的

a  f=f(a)a\ |\ f = f(a)

Compose(f1,f2,,fN)(a1aM)=f1(a1aM)  f2    fNCompose(f_1, f_2, \dots, f_N)(a_1 \dots a_M) = f_1(a_1\dots a_M) \ |\ f_2 \ |\ \dots\ |\ f_N

在表达能力不弱的C++中也能实现,我们这里选择管道符|重载函数调用关系,当然其他二元operator比如逗号也可以,|的话好看一点,以及Unix管道与C++ 20 ranges都是使用|表示组合关系的

首先实现operator|的重载,同时为了处理void,我们引入一个辅助的空类型None

struct None {};
template <typename Arg, typename Func>
decltype(auto) operator|(Arg &&a, Func f) {
  using Ret = std::invoke_result_t<Func, decltype(a)>;
  if constexpr (std::is_void_v<Ret>) {
    f(a);
    return None{};
  } else
    return f(a);
}
template <typename Func> decltype(auto) operator|(None, Func f) {
  using Ret = std::invoke_result_t<Func>;
  if constexpr (std::is_void_v<Ret>) {
    f();
    return None{};
  } else
    return f();
}

首先a | f返回f(a),但如果返回void的话就转而返回一个None{}

然后我们再为None重载一个operator|(第一个形参是None不是泛型,重载优先级更高),避免None{}被传入当前函数

测试一下

auto add = [](auto a, auto b) { return a + b; };
auto multiply = [](auto x) { return x * 2; };
auto square = [](auto x) { return x * x; };
cout << (add(3, 6) | multiply | square) << endl;
// 324

现在只剩实现Compose(f1,f2,,fN)(a1aM)=f1(a1aM)  f2    fNCompose(f_1, f_2, \dots, f_N)(a_1 \dots a_M) = f_1(a_1\dots a_M) \ |\ f_2 \ |\ \dots\ |\ f_N

我们可以用C++ 14折叠表达式,简单说,如果有边长模板参数...F = F2, F3 ... FN,那么(result_of_F1 | ... | F)等价于(((result_of_F1 | F2) | F3)... | FN),这样结合重载的|就实现函数组合了

template <typename ...Func> struct Compose {
private:
  // ...
  template <std::size_t... Indexes, typename... ArgT>
  decltype(auto) _invoke(FuncChain &funcs,
                         std::integer_sequence<std::size_t, Indexes...>,
                         ArgT &&...args) {
    return (std::get<0>(funcs)(std::forward<ArgT>(args)...) | ... |
            std::get<Indexes + 1>(funcs));
  };

public:
  // ...
  template <typename... ArgT> decltype(auto) operator()(ArgT &&...args) {
    return _invoke(funcs, std::make_index_sequence<sizeof...(Func) - 1>{},
                   std::forward<ArgT>(args)...);
  }
};

这里operator()_invoke传递了一个std::make_index_sequence<sizeof...(Func) - 1>,也就是std::integer_sequence<std::size_t, 0, 1, 2, sizeof...(Func) - 2>,这样_invoke内部就能解构出这些下标,有了下标我们就能索引第1(从0计数)个函数到最后一个函数了

_invoke内,我们首先用std::get<0>(funcs)(std::forward<ArgT>(args)...) 拿到第0个函数返回值,然后用折叠表达式展开operator()传入的下标,链式调用后续函数

测试代码

using namespace std;
int main() {
  auto add = [](auto a, auto b) { return a + b; };
  auto multiply = [](auto x) { return x * 2; };
  auto square = [](auto x) { return x * x; };
  cout << (add(3, 6) | multiply | square) << endl;
  cout << Compose(add, multiply, square) (3, 6)<< endl;
  // Case0: Normal func chain, output 324 324

  string s{"nihao"};
  auto get_ref_to_my_name = [&]() -> string & { return s; };
  auto then_modify_it = [](string &sref) -> void {
    sref = "goodbye";
  };
  auto and_print = [&]() -> void {
    cout << "s == \"" << s << "\" now\n";
  };
  auto &&action1 = Compose(get_ref_to_my_name, then_modify_it, and_print);
  action1();
  // Case1: compose contains void on functional chain; output:s == "goodbye" now

  auto &&meaningless = Compose([] {});
  meaningless();
  // Case2: single compose 

  auto f1 = [](int a, int b) { return a + b; };
  auto f2 = [](int x) -> double { return x * 1.16666666666666; };
  auto &&action2 = Compose(
      Compose(f1, f2), [](auto x) { return "Result: " + std::to_string(x); });
  cout << action2(-3, 5);
  // Case3: Nested compose
  // output:
  // Result: 2.333333
}

包含4种情况:普通组合、中间有void的组合、单个函数组合、嵌套组合

Full code: [Github]

Currying

接下来顺便写一个柯里化(Currying),这里涉及到了函数形参类型的解构,所以我们主要利用std::function类实现柯里化

柯里化(Currying)是一种将多参数函数转换为一系列单参数函数的技术

对于NN元函数f(x1,x2,,xN)f(x_1, x_2, \dots, x_N), 令Curry(f)Curry(f)为柯里化后的ff

那么Curry(f)(a1)(a2)(aN)=f(a1,a2,aN)Curry(f)(a_1)(a_2)\dots (a_N) = f(a_1, a_2, \dots a_N)

g=af(x1=a1,x2,xN)g = a \rightarrow f(x_1 = a_1, x_2, \dots x_N )

由于$Curry(g(a_1))(a_2)(a_3)\dots(a_N) = g(a_1)(a_1, a_2, \dots a_N) $

两式结合也就是Curry(f)(a1)(a2)(aN)=Curry(g(a1))(a2)(a3)(aN)Curry(f)(a_1)(a_2)\dots (a_N) = Curry(g(a_1))(a_2)(a_3)\dots(a_N)

所以Curry(f)=aCurry(g)=aCurry(f(x1=a,x2,,xN))Curry(f) = a \rightarrow Curry(g) = a \rightarrow Curry(f(x_1 = a, x_2, \dots ,x_N))

所以我们将Curry实现为一个递归闭包

template <typename RetT, typename... ArgsT>
auto Curry(std::function<RetT(ArgsT...)> func) {
  if constexpr (sizeof...(ArgsT) < 2)
    return func;
  else
    return _curry_combine<RetT, ArgsT...>(std::move(func));
};

创建一个模板函数Curry解构出std::function的返回值类型RetT和形参类型...ArgsT

Curry接收原始std::function,输出柯里化后的闭包,或原始函数的形参数量少于2,那就直接返回原函数,这里使用使用if constexpr条件编译

template <typename Return, typename Arg, typename... Args>
auto _curry_combine(std::function<Return(Arg, Args...)> original) {
   // f(x_1, x_2, ... x_N)
  if constexpr (sizeof...(Args) == 0)
    return original;
  else
    return [f = std::move(original)](Arg &&arg) { // arg -> curry(f(arg, x_2, x_3, ..., x_N))
      return _curry_combine(//curry(f(x_1 = arg, x_2, x_3, ..., x_N))
          std::function([&f, &arg](Args &&...args) {
       		 return f(std::forward<Arg>(arg), std::forward<Args>(args)...);
              // f((x_1 = arg, x_2, x_3, ..., x_N)
      }));
    };
}

_curry_combine用于生成递归闭包,每一层闭包存储当前的形参arg与上一层闭包的引用f即可

最后我们处理的都是std::function,但std::function能从C++函数或函数指针或函数对象构造

template <typename FuncLike>
auto Curry(FuncLike&& func) {
  return Curry(std::function{std::forward<FuncLike>(func)});
};

所以再重载一个Curry支持其他的可调用对象,构造相应的function

测试代码

auto f(int a, int b, int c, int d, int e, int f) {
  return a + b + c + d + e + f;
}
auto f2(int& ref, int val) { ref = val; }
int main() {
  auto cf1 = Curry(f);
  std::cout << cf1(1)(1)(4)(5)(1)(4)<< std::endl;
  // Case0: normal
  // 16
  int what;
  auto cf2 = Curry(f2);
  auto &&setter = cf2(what);
  setter(107);
  std::cout << what << std::endl;
  //Case1: Delay Invoke
  //107
    
  auto cf3 = Curry([] (std::string const &a, std::string const & b) {return a + b;});
  std::cout << cf3(std::string{"hola"})(std::string{" amigo"}) << std::endl;
  // Case2: lambda
  // hola amigo
}

以上函数组合与柯里化的完整实现代码:[compose.hpp, curry.hpp]