自前ライブラリ自動挿入 for 競プロ

kyokkounoite.hatenablog.jp

前回の続き。テンプレート化したアルゴリズムたちを毎回コピペするのは面倒なので自動化した。


target


environment example

筆者は Visual Studio Code の Workspace Folder 以下のディレクトリを(無関係なものを除いて)以下のようにしている。

.
├── atcoder-workspace
│   ├── arc165
│   :   ├── A
│       :   └── main.cpp
│
├── library
│   ├── .include_base.hpp
│   ├── edge.hpp
│   ├── minimum_spanning_tree_kruskal.hpp
│   ├── union_find.hpp
│   :
│
└── library_inserter.py


code of library inserter

処理はそこまで重くないので Python で書いてみた。

以下のコードは自己責任の範疇において自由に複製・利用・改変してもらって構わない。

# library_inserter.py

import sys
import re


# 0-based index of the line on which every library writes `#include ".include_base.hpp"`.
# That line is never copied into main.cpp.
INCLUDE_BASE_LINE: int = 0

# Only the first MAX_INCLUDE_LINES lines of a library are searched for its
# include guard and for the `#include` lines of the libraries it depends on.
MAX_INCLUDE_LINES: int = 10

# Regular expression searched for in main.cpp; libraries are inserted directly
# above the first line that matches it.
MAIN_LIBRARY_END: str = r'<templates end>'

# When True, a blank line directly above a dependency `#include` line is dropped
# together with that `#include` line.
REMOVE_EMPTY_LINE: bool = True


def recursive_inserter(main_data: list, library_path: str, already_inserted: set, insert_line: int) -> int:
    """Insert the library at ``library_path`` into ``main_data``, dependencies first.

    Only the first ``MAX_INCLUDE_LINES`` lines of the library are searched for its
    include-guard ``#define ...HPP`` and for ``#include "....hpp"`` lines naming other
    libraries. Each such dependency is inserted (recursively) before the library
    itself. When the library text is copied, three kinds of line are dropped:
    the dependency ``#include`` lines, line ``INCLUDE_BASE_LINE``, and (if
    ``REMOVE_EMPTY_LINE``) a blank line directly above a dependency ``#include``.
    Two blank lines are appended after each inserted library.

    Args:
        main_data: lines of main.cpp; modified in place.
        library_path: path of the library to insert. Dependency paths are opened
            exactly as written in the ``#include`` line, i.e. relative to the
            current working directory.
        already_inserted: include-guard macro names already present in main.cpp;
            updated in place.
        insert_line: index in ``main_data`` at which to insert.

    Returns:
        The index just past the inserted text (``insert_line`` unchanged if the
        library was already present or has no recognizable include guard).
    """
    with open(library_path) as library_file:
        lines = library_file.readlines()

    header = lines[:MAX_INCLUDE_LINES]

    # Identify the library by its include-guard macro.
    defined = None
    for line in header:
        body = re.fullmatch(r'#define ([A-Z0-9_]+HPP)', line.strip())
        if body is not None:
            defined = body.group(1)
            break

    if defined is None:
        # Without a guard the library cannot be deduplicated, so it is not inserted.
        print(f'warning: no include guard found in the first {MAX_INCLUDE_LINES} lines of "{library_path}"; skipped.',
              file=sys.stderr)
        return insert_line

    if defined in already_inserted:
        return insert_line

    # Register before recursing so that cyclic includes terminate.
    already_inserted.add(defined)

    # Insert dependencies first and remember which header lines must not be copied.
    skip = {INCLUDE_BASE_LINE}
    empty_line_above = False

    for index, line in enumerate(header):
        if index != INCLUDE_BASE_LINE:
            body = re.match(r'#include "([A-Za-z0-9_/]+\.hpp)"', line.strip())

            if body is not None:
                insert_line = recursive_inserter(main_data, body.group(1), already_inserted, insert_line)

                skip.add(index)
                if REMOVE_EMPTY_LINE and empty_line_above:
                    skip.add(index - 1)

        empty_line_above = (line.strip() == '')

    # Copy the library itself.
    for index, line in enumerate(lines):
        if index in skip:
            continue
        # Keep exactly one line per list entry even if the file lacks a trailing newline.
        if not line.endswith('\n'):
            line += '\n'
        main_data.insert(insert_line, line)
        insert_line += 1

    main_data[insert_line:insert_line] = ['\n', '\n']
    return insert_line + 2


def main(main_path: str, library_path: str):
    """Insert ``library_path`` (with its dependencies) into ``main_path`` in place.

    The libraries are placed directly above the first line of ``main_path`` that
    matches the regex ``MAIN_LIBRARY_END``. Include guards (``#define ...HPP``)
    found above that line are treated as already inserted, so running the script
    again does not duplicate a library. The number of inserted lines is printed
    to stdout.

    If no line matches ``MAIN_LIBRARY_END``, an error is printed to stderr, the
    file is left untouched and the process exits with status 1.
    """
    with open(main_path) as main_file:
        main_data = main_file.readlines()

    already_inserted = set()
    insert_line = None

    for index, data in enumerate(main_data):
        # Same guard pattern as recursive_inserter, so both sides recognize the same libraries.
        body = re.fullmatch(r'#define ([A-Z0-9_]+HPP)', data.strip())

        if body is not None:
            already_inserted.add(body.group(1))
        elif re.search(MAIN_LIBRARY_END, data) is not None:
            insert_line = index
            break

    if insert_line is None:
        # Appending at end of file would put the libraries after the user's code; refuse instead.
        print(f'"library_inserter.py" failed: no line matching {MAIN_LIBRARY_END!r} in "{main_path}".',
              file=sys.stderr)
        sys.exit(1)

    print(recursive_inserter(main_data, library_path, already_inserted, insert_line) - insert_line)

    # The file was read with universal newlines, so it is written back with '\n' line endings.
    with open(main_path, 'w', newline='') as main_file:
        main_file.writelines(main_data)


if __name__ == '__main__':
    # Usage: python library_inserter.py <main_path> <library_path>
    if len(sys.argv) == 3:
        main(sys.argv[1], sys.argv[2])
    else:
        print('"library_inserter.py" failed due to the wrong parameter.', file=sys.stderr)
        print('The correct format is: $ library_inserter.py <main_path> <library_path>', file=sys.stderr)
        # Non-zero status so that the caller (e.g. a VS Code task) can tell the run failed.
        sys.exit(1)

細かな挙動については上記コードを参照のこと。


how to use

当然ながら path は環境に合わせて適宜変更すること。

.include_base.hpp

"main.cpp" には記述済みなので挿入したくないが、各ライブラリのメンテナンス性のために記述しておきたい部分。

// Common prelude shared by every library header. Libraries include it only so that
// each header can be edited and compiled on its own; it is not copied into main.cpp
// when other libraries are inserted.
// The guard macro must not start with "_" + uppercase letter: such names are reserved.
#ifndef INCLUDE_BASE_HPP
#define INCLUDE_BASE_HPP
#include <bits/stdc++.h>
using namespace std;
constexpr int MOD = 998244353;
#endif  // INCLUDE_BASE_HPP
[[ library_name ]].hpp

以下を順不同でヘッダファイル上部に記述する。

  • ".include_base.hpp" のインクルード
    • そのインクルードを記述した行番号 (0-indexed) を INCLUDE_BASE_LINE ("library_inserter.py") に代入する
  • インクルードガード (マクロ名は `EDGE_HPP` のように `HPP` で終わる大文字の名前にする。スクリプトはこの `#define` 行でライブラリを識別する)
  • 他の依存ライブラリのインクルード

これらがすべてファイル先頭から MAX_INCLUDE_LINES ("library_inserter.py") 行以内に収まるように配置する (それより後ろの行はインクルードガードや依存ライブラリの検索対象にならない)。

// minimum_spanning_tree_kruskal.hpp

#include ".include_base.hpp"
#ifndef MINIMUM_SPANNING_TREE_KRUSKAL_HPP
#define MINIMUM_SPANNING_TREE_KRUSKAL_HPP

#include "edge.hpp"
#include "union_find.hpp"

// Kruskal's algorithm: total cost of a minimum spanning tree of the graph with
// vertices 0..v-1 and the given edges.
// - `edges` is sorted in place (edge<T>::operator< compares cost only).
// - every edge needs valid `from` and `to` in [0, v); the 2-argument edge
//   constructor leaves `from` at -1, so such edges must not be passed here.
// - throws -1 if the graph is disconnected; v == 0 returns 0.
template<typename T>
T kruskal(Edges<T>& edges, int v) noexcept(false) {
  sort(edges.begin(), edges.end());
  UnionFind tree(v);
  T ret = 0;
  for(auto& e : edges) {
    if(tree.unite(e.from, e.to)) ret += e.cost;
  }
  // Guard v == 0: tree.size(0) would index an empty UnionFind.
  if(v > 0 && tree.size(0) != v) throw -1;
  return ret;
}

#endif  // MINIMUM_SPANNING_TREE_KRUSKAL_HPP
main.cpp

MAIN_LIBRARY_END ("library_inserter.py") に格納した文字列を含む行を用意すること (この文字列は正規表現として検索される)。ライブラリはこの行の直上に挿入されるため、この行が挿入ライブラリとコード本体の境界線となる。

tasks.json

本当は library フォルダ内のファイル一覧を動的に取得して Visual Studio Code に反映させたかったが、結構な手間がかかりそうだったので諦めた。ライブラリを追加・削除する頻度はかなり低いはずなので大した問題ではなさそう。

{
  "tasks": [
    {
      "label": "insert library file",
      "type": "shell",
      "command": "python ../library_inserter.py '${file}' '${input:library_filename}'",
      "presentation": {
        "echo": false,
        "reveal": "silent",
        "focus": false,
        "panel": "shared",
        "showReuseMessage": false,
        "clear": false,
        "close": false
      },
      "options": {
        "cwd": "${workspaceFolder}/library"
      }
    }
  ],
  "inputs": [
    {
      "id": "library_filename",
      "type": "pickString",
      "description": "enter the template_filename",
      "options": [
        "edge.hpp",
        "minimum_spanning_tree_kruskal.hpp",
        "union_find.hpp",
        ...

      ],
      "default": ""
    }
  ]
}
keybindings.json

前回同様、Ctrl+K Ctrl+S などから開く。

[
  {
    "key": "Ctrl+F7",
    "command": "workbench.action.tasks.runTask",
    "when": "editorTextFocus && resourcePath =~ /^.*\/atcoder-workspace\/.*\/main.cpp$/",
    "args": "insert library file"
  }
]


winning run...

挿入後にエディタは挿入箇所まで自動でスクロールしないので、矢印キーなどで移動して確認する。

example of main.cpp with minimum_spanning_tree_kruskal.hpp here

#include <bits/stdc++.h>
using namespace std;

// ---- shorthand macros: whole-range / reversed-range iterator pairs ----
#define ALL(a) begin(a), end(a)
#define RALL(a) rbegin(a), rend(a)

// ---- type aliases ----
using ll = long long;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
template<class T> using Graph = vector<vector<T>>;
template<class T> using Spacial = vector<vector<vector<T>>>;  // 3-level nested vector
template<class T> using greater_priority_queue = priority_queue<T, vector<T>, greater<T>>;  // top() is the smallest

// ---- constants ----
constexpr int MOD = 10;
// (dx[i], dy[i]) for i = 0..3 are the four axis-aligned unit steps.
const int dx[4] = {1, 0, -1, 0};
const int dy[4] = {0, 1, 0, -1};
// Element separators used by operator<< for vector: [0] = ' ', [1] = '\n'.
char interval[2] = {' ', '\n'};

// make_vector<T>(x, n1, n2, ..., nk): nested vector of sizes n1 x n2 x ... x nk
// whose innermost elements are all copies of x.
template<typename T, typename... Args> auto make_vector(T x, int arg, Args... args) {
  if constexpr(sizeof...(args) == 0) {
    return vector<T>(arg, x);
  } else {
    auto inner = make_vector<T>(x, args...);
    return vector<decltype(inner)>(arg, inner);
  }
}

// is_plural<T>::value selects the separator operator<< puts between vector elements:
// true for types that are themselves printed as several tokens.
template<class T> struct is_plural : bool_constant<false> {};
template<class T1, class T2> struct is_plural<pair<T1, T2>> : bool_constant<true> {};
template<class T> struct is_plural<vector<T>> : bool_constant<true> {};
template<class T> struct is_plural<complex<T>> : bool_constant<true> {};
template<> struct is_plural<string> : bool_constant<true> {};

// Reads / writes a pair as "first second".
template<typename T1, typename T2> istream& operator>>(istream& is, pair<T1, T2>& p) { return is >> p.first >> p.second; }
template<typename T1, typename T2> ostream& operator<<(ostream& os, const pair<T1, T2>& p) { return os << p.first << ' ' << p.second; }
// Reads every element in order; vec must already have the desired size.
template<typename T> istream& operator>>(istream& is, vector<T>& vec) { for(auto&& x : vec) is >> x; return is; }
// Prints the elements separated by interval[is_plural<T>] (a space, or a newline when each element is itself printed as several tokens).
// next() is used instead of ++vec.begin(): incrementing a prvalue iterator is ill-formed if the iterator is a raw pointer.
template<typename T> ostream& operator<<(ostream& os, const vector<T>& vec) { if(vec.empty()) return os; bool pl = is_plural<T>(); os << vec.front(); for(auto itr = next(vec.begin()); itr != vec.end(); ++itr) os << interval[pl] << *itr; return os; }
// Reads "re im". a and b are value-initialised so a failed read never assigns indeterminate values to x.
template<typename T> istream& operator>>(istream& is, complex<T>& x) { T a{}, b{}; is >> a >> b; x = complex<T>(a, b); return is; }
template<typename T> ostream& operator<<(ostream& os, const complex<T>& x) { return os << x.real() << ' ' << x.imag(); }

// Prints `yes` when `a` holds and `no` otherwise, each followed by '\n'; returns `a` unchanged.
bool CoutYN(bool a, string yes = "Yes", string no = "No") {
  if(a) {
    cout << yes << '\n';
  } else {
    cout << no << '\n';
  }
  return a;
}

// chmax: if a < b, assign b to a and return true; otherwise leave a untouched and return false.
template<typename T1, typename T2> inline bool chmax(T1& a, T2 b) {
  if(!(a < b)) return false;
  a = b;
  return true;
}
// chmin: if a > b, assign b to a and return true; otherwise leave a untouched and return false.
template<typename T1, typename T2> inline bool chmin(T1& a, T2 b) {
  if(!(a > b)) return false;
  a = b;
  return true;
}

// Forward declaration; the definition sits after solve(). debug(x, y, ...) passes the
// current line number, the argument text and the argument values to it.
template<typename... Ts> void debugger(int line, const char* names, const Ts&... values);
#define debug(...) debugger(__LINE__, #__VA_ARGS__, __VA_ARGS__)


/* -------- <insert libraries below> -------- */


#ifndef EDGE_HPP
#define EDGE_HPP

// Weighted edge `from` -> `to` with weight `cost`.
// < and > compare `cost` only, while == and != compare all three fields.
// Converts implicitly to int (the destination `to`).
template<typename T>
struct edge {
  int from, to;
  T cost;

  // Edge whose source is not recorded: `from` is set to -1.
  edge(int to, T cost) : edge(-1, to, cost) {}

  edge(int from, int to, T cost) : from(from), to(to), cost(cost) {}

  operator int() const { return to; }

  bool operator<(const edge<T>& e) const { return cost < e.cost; }
  bool operator>(const edge<T>& e) const { return cost > e.cost; }

  bool operator==(const edge<T>& e) const {
    return from == e.from && to == e.to && cost == e.cost;
  }
  bool operator!=(const edge<T>& e) const { return !operator==(e); }
};

template<typename T> using Edges = vector<edge<T>>;
template<typename T> using WeightedGraph = vector<vector<edge<T>>>;

#endif  // EDGE_HPP


#ifndef UNION_FIND_HPP
#define UNION_FIND_HPP

// Disjoint-set union over elements 0..sz-1 (union by size, full path compression).
// data[k] < 0  : k is a root and -data[k] is the size of its set.
// data[k] >= 0 : data[k] is the parent of k.
struct UnionFind {
  vector<int> data;

  UnionFind(int sz) : data(sz, -1) {}

  // Merges the sets containing x and y; returns false if they were already one set.
  bool unite(int x, int y) {
    x = find(x);
    y = find(y);
    if(x == y) return false;
    // Keep x as the root of the larger set (more negative data value).
    if(data[y] < data[x]) swap(x, y);
    data[x] += data[y];
    data[y] = x;
    return true;
  }

  // Returns the root of k's set and points every node on the path directly at it.
  int find(int k) {
    int root = k;
    while(data[root] >= 0) root = data[root];
    while(data[k] >= 0) {
      const int parent = data[k];
      data[k] = root;
      k = parent;
    }
    return root;
  }

  // Number of elements in k's set.
  int size(int k) {
    return -data[find(k)];
  }
};

#endif  // UNION_FIND_HPP


#ifndef MINIMUM_SPANNING_TREE_KRUSKAL_HPP
#define MINIMUM_SPANNING_TREE_KRUSKAL_HPP

// Kruskal's algorithm: total cost of a minimum spanning tree of the graph with
// vertices 0..v-1 and the given edges.
// - `edges` is sorted in place (edge<T>::operator< compares cost only).
// - every edge needs valid `from` and `to` in [0, v); the 2-argument edge
//   constructor leaves `from` at -1, so such edges must not be passed here.
// - throws -1 if the graph is disconnected; v == 0 returns 0.
template<typename T>
T kruskal(Edges<T>& edges, int v) noexcept(false) {
  sort(edges.begin(), edges.end());
  UnionFind tree(v);
  T ret = 0;
  for(auto& e : edges) {
    if(tree.unite(e.from, e.to)) ret += e.cost;
  }
  // Guard v == 0: tree.size(0) would index an empty UnionFind.
  if(v > 0 && tree.size(0) != v) throw -1;
  return ret;
}

#endif  // MINIMUM_SPANNING_TREE_KRUSKAL_HPP


/* -------- <templates end> -------- */


// Solution for the current problem goes here; main() calls it exactly once.
void solve() {
}


/* -------- <programs end> -------- */


// Remove the next line to make debug(...) a no-op.
#define DEBUG
#ifdef DEBUG
// Ends the output of one debug(...) call with a newline.
void dbg() {
  cerr << '\n';
}
// Prints every value on its own line (each value is preceded by '\n').
template<typename T, typename... Args>
void dbg(const T& x, const Args&... args) {
  cerr << '\n' << x;
  dbg(args...);
}
// Backend of debug(...): writes "<line> [<argument text>]:" followed by the values to stderr.
template<typename... Args>
void debugger(int line, const char* str, const Args&... args) {
  cerr << line << " [" << str << "]:";
  dbg(args...);
}
#else
// DEBUG undefined: debug(...) calls a function that does nothing.
template<typename... Args>
void debugger(int line, const char* str, const Args&... args) {}
#endif

#define COMPLEX_COMPARE
#ifdef COMPLEX_COMPARE
// Lexicographic order on complex numbers: real part first, then imaginary part
// (presumably so complex<T> can be sorted / kept in ordered containers).
// NOTE(review): adding an overload to namespace std is not permitted by the C++
// standard ([namespace.std]); it is relied on in practice but is not guaranteed
// to be portable.
namespace std {
template<typename T> bool operator<(const complex<T>& l, const complex<T>& r) {
  if(real(l) != real(r)) return real(l) < real(r);
  return imag(l) < imag(r);
}
}  // namespace std
#endif

// Entry point: unties cin from cout, disables C stdio synchronisation, prints
// floating-point values in fixed notation with 10 digits, then runs solve().
signed main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  cout.setf(ios::fixed, ios::floatfield);
  cout.precision(10);
  solve();
  return 0;
}