パトリシア木

C++ で競プロ用で良い感じの実装をしてる人が見つからなかったのでここに置いておきます ( が、使用頻度はかなり少ないと思います ) 。

アルゴリズム概要

ja.wikipedia.org

トライ木の一種。
通常のトライ木は単一の文字によってラベリングされているが、パトリシア木は部分文字列によってラベリングされているため、メモリ使用量や探索時間を低減させることができる。

実装方針

ベースとなるトライ木はこちらを参考にした。

トライ木(Trie) | Luzhiled’s memo

追加すべき項目は、

  • 各ノードに対応する文字列の情報を格納する
  • ノードを分割する

である。


ノード分割は、例えば [ patricia ] を [ pat ] [ ricia ] に分割する際の方針として、

  • ノード [ pat ] を新しく作成し、既存ノードを [ ricia ] に変更する
  • 既存ノードを [ pat ] に変更し、ノード [ ricia ] を新しく作成する

の 2 通りが考えられるが、後者の場合は子ノードの情報を書き換える必要があり定数倍が重くなるため、前者の方針で実装した。


文字列情報の型は string でも良いが、今回はノード分割との兼ね合いで deque を用いた。
既存ノードの文字列情報を [ patricia ] から [ ricia ] に変更する際に、deque なら定数時間で前方要素を削除できる。 一方 string だと、例示したような短い文字列なら問題ないが、長い文字列になると書き換えコストがかなり大きくなってしまう。


なお、文字列を削除する操作は使う機会がほぼなさそうなので、実装していない。

実装

#include <bits/stdc++.h>
using namespace std;

template<int char_size>
struct PatriciaNode {
  deque<char> prefix;
  int nxt[char_size];

  int exist;
  vector<int> accept;

  PatriciaNode() : exist(0) {
    memset(nxt, -1, sizeof(nxt));
  }

  PatriciaNode(const string &str, int str_index_left, int str_index_right = -1) : exist(0) {
    memset(nxt, -1, sizeof(nxt));
    if(str_index_right == -1) str_index_right = (int)str.size();
    for(int i = str_index_left; i < str_index_right; ++i) {
      prefix.emplace_back(str[i]);
    }
  }

  int sum() const {
    return exist + (int)accept.size();
  }
};

template<int char_size, int margin>
struct PatriciaTree {
  using Node = PatriciaNode<char_size>;

  vector<Node> nodes;
  int root;

  PatriciaTree() : root(0) {
    nodes.emplace_back();
  }

  void update_direct(int node_index, int id) {
    nodes[node_index].accept.emplace_back(id);
  }

  void update_child(int node_index, int child_index, int id) {
    ++nodes[node_index].exist;
  }

  void split(const string &str, int str_index, int node_index, int nxt_index, int len) {
    int insert_index = (int)nodes.size();
    nodes.emplace_back(str, str_index, str_index + len);
    for(int i = 0; i < len; ++i) {
      nodes[nxt_index].prefix.pop_front();
    }
    nodes[node_index].nxt[str[str_index] - margin] = insert_index;
    nodes[insert_index].nxt[nodes[nxt_index].prefix[0] - margin] = nxt_index;
    nodes[insert_index].exist = nodes[nxt_index].sum();
  }

  void add(const string &str, int str_index, int node_index, int id) {
    if(str_index == (int)str.size()) {
      update_direct(node_index, id);
    } else {
      const int c = str[str_index] - margin;
      if(nodes[node_index].nxt[c] == -1) {
        nodes[node_index].nxt[c] = (int)nodes.size();
        nodes.emplace_back(str, str_index);
      } else {
        int nxt_index = nodes[node_index].nxt[c];
        deque<char> &prefix = nodes[nxt_index].prefix;
        for(int i = 0; i < (int)prefix.size(); ++i) {
          if(str_index + i == (int)str.size() || str[str_index + i] != prefix[i]) {
            split(str, str_index, node_index, nxt_index, i);
            break;
          }
        }
      }
      int nxt_index = nodes[node_index].nxt[c];
      add(str, str_index + (int)nodes[nxt_index].prefix.size(), nxt_index, id);
      update_child(node_index, nxt_index, id);
    }
  }

  void add(const string &str, int id) {
    add(str, 0, 0, id);
  }

  void add(const string &str) {
    add(str, nodes[0].exist);
  }

  void query(const string &str, const function<bool(int, int)> &f, int str_index, int node_index) {
    if(f(str_index, node_index)) return;
    if(str_index == (int)str.size()) {
      return;
    } else {
      int nxt_index = nodes[node_index].nxt[str[str_index] - margin];
      if(nxt_index == -1) return;
      query(str, f, str_index + (int)nodes[nxt_index].prefix.size(), nxt_index);
    }
  }

  void query(const string &str, const function<bool(int, int)> &f) {
    query(str, f, 0, 0);
  }

  int count() const {
    return nodes[0].exist;
  }

  int size() const {
    return (int)nodes.size();
  }
};

注意点

query 操作を行う際に引数となる文字列は既に挿入された文字列であることが前提となっているため、各ノードに格納された文字列情報との比較は先頭 1 文字分しか保証していない。
正確に比較したい場合は、引数となる function 内で比較してほしい ( 一致しない場合は function の返り値を true にすることで query が終了する ) 。

また、query 1 回あたりの最悪計算量は、正確な比較を行わない場合でも O ( f * sqrt N) となり、かなり重たい。
( f : 引数となる function の計算量、N : 挿入した文字列の長さの和 )

検証

atcoder.jp