添字and,or畳み込み
概要
添字andとorの畳み込みの計算方法について勉強したことを書きます。 勉強したことと書いてある通り、自分自身が勉強したばかりなので誤り等あるかもしれません。
仕組み
まず畳み込みの定義をします。
数列が与えられた時に
で定義される数列がそれぞれ、添字andと添字orの畳み込みです。 添字に関してandとorで畳み込んでいますね。 まずandの畳み込みについて考えます。andの畳み込みができてしまえば、orの畳み込みもほぼ同じ方法でできます。 この計算をするために次のような数列の変換を考えます。
このとき数列の変換を行うことによって、
となります。つまりもともと畳み込みの計算だったものが各点積という簡単な計算に変化しています。
(1)の証明のようなもの
2行目から3行目が一番わかりづらいです。からはを完全に含んでいるので、 もを完全に含んでいるので、えいっとシグマの取り方を変えられます。
fの逆変換
ところでこの計算で求めることができたのは、求めたかったではなくです。 もしの逆変換のようなものがあればからを得ることができます。実はそのような変換は存在します!!
次の変換はの逆変換になっています。
popcount(i)はiの立っているビットの数を表しています。つまりiを立っているビットが表す集合とみたときの集合の要素と等しいです。 ちなみに先程のと異なる部分は-1の累乗をかけている部分だけです。
gとfが互いに逆変換であることの証明のようなもの
を示せばいいはず。
二つのシグマの条件式をみるとはビットの集合の包含関係としてとなっていることがわかります。
[1] のとき
しかないからの係数はです。
[2] のとき
であるについて の総和をとったものがの係数になります。 これは包除原理でよくあるやつで、とすると、
が成り立つのでの係数はになります。
よってが成り立つことが示せました。これの逆も同じ感じで示せます。
添字and畳み込み
まとめると次の計算をするとandの畳み込みが計算できます。
- 数列に変換をかける。
- それらの数列の各点積を取る。
- できた数列にの逆変換をかける。
ちなみにこのには名前がついています。をゼータ変換、をメビウス変換といいます。 ちなみにゼータ変換には上位集合と下位集合のゼータ変換があって、この変換は上位集合のゼータ変換です。 メビウス変換も同様でこれは上位集合のメビウス変換です。
下のようなゼータ変換やメビウス変換をよく見ると思うんですが、
これは最初に紹介した変換のゼータ変換の添字をビットが表す集合と思うと同じになります。
添字or畳み込み
ここまでandの畳み込みしか書いてきていませんがorの畳み込みも同様にできます。 じつは先程の上位集合のゼータ変換、メビウス変換を下位集合ののゼータ変換、メビウス変換に変えるとorの畳み込みになります。 下位集合のゼータ変換、メビウス変換は次のような式で表されます。
証明もほぼ同じ感じでできると思います。 一応集合の形でも書いておきます。
高速ゼータ変換、メビウス変換
ゼータ変換やメビウス変換を愚直にやると数列の長さをNとしてかかります。 これをで行うのが高速ゼータ変換、メビウス変換です。 高速ゼータ変換、メビウス変換ができればandとorの畳み込みはで計算することができます。 説明は省略しますが、下のようなコードで実現できます。bitdpのようなことをやっているらしい。それぞれのbitについて累積和をしているといってもいいかも。 配列の長さは2の冪乗にしておくといいと思います。
コード
間違っている可能性があります。
上位集合のゼータ変換
void zeta_sup(vector<int> &f) { int n = (int)f.size(); for (int i = 1;i < n;i <<= 1) { for (int j = 0;j < n; j++) { if (!(j&i)) f[j] += f[j|i]; } } }
上位集合のメビウス変換
void mebius_sup(vector<int> &f) { int n = (int)f.size(); for (int i = 1;i < n;i <<= 1) { for (int j = 0;j < n; j++) { if (!(j&i)) f[j] -= f[j|i]; } } }
下位集合のゼータ変換
void zeta_sub(vector<int> &f) { int n = (int)f.size(); for (int i = 1;i < n;i <<= 1) { for (int j = 0;j < n; j++) { if (j&i) f[j] += f[j^i]; } } }
下位集合のメビウス変換
void mebius_sub(vector<int> &f) { int n = (int)f.size(); for (int i = 1;i < n;i <<= 1) { for (int j = 0;j < n; j++) { if (j&i) f[j] -= f[j^i]; } } }
andの畳み込み
vector<int> and_convolution(vector<int> a,vector<int> b) { assert(a.size() == b.size()); zeta_sup(a); zeta_sup(b); int n = (int)a.size(); for(int i = 0;i < n; i++) a[i] *= b[i]; mebius_sup(a); return a; }
orの畳み込み
vector<int> or_convolution(vector<int> a,vector<int> b) { assert(a.size() == b.size()); zeta_sub(a); zeta_sub(b); int n = (int)a.size(); for(int i = 0;i < n; i++) a[i] *= b[i]; mebius_sub(a); return a; }