【苦しみながら理解する強化学習】A3Cの実装を覗く

いつもながら sugulu さんの実装を読んでいきます。
この人は頭いいし、謙虚そうだし完璧なのか。

A3C

しかしながらA3Cは、DQNの次の世代的存在であるため、DQNからの変化幅が大きく、理解するのがなかなか難しいです。

また、figたちに関してもめちゃくちゃわかりやすい。
ペーパーを読んでまず箇条書きやfigに落とし込む。
その後、実装に取り組むという流れが論文実装の方法なのかと察した。
そして、実装中もメモをとってclass分けなどをすると、fig3のようなものも書けそうだ。

fig3

論文実装の時は以下を心がける。

  • メモや図を取りながら論文を読む。
  • 実装内容を箇条書きしてまとめる。
  • 実装中もclass構成などのメモをとる。

しかし、A3Cのペーパー読んでこんなことができるか一度やった方がよさそうだ…
今回は解説を読みながら自分なりにメモしていこう。

実装内容

DQNに比べて更に複雑になってきたので、分類を更に分けたい。

  • 行動の価値を予測する実装
    • 予測用NN -> Actor-Critic, Advantage
    • 予測価値の算出 -> Actor-Critic, Advantage
    • 予測用NNの更新処理方法 -> Asynchronous
  • 実行した行動から報酬を計算する実装 -> 今回特に変化

(だいたいDQNからの発展と捉えています。)

行動の価値を予測する実装

予測用のNN

suguluさんの以下の図がわかりやすいのですが、

fig3

何本もthreadを走らせて重み学習させ、その後大元の重み(図の中ではParameter Server)を更新するという手法をとります。
そのため、構成は一緒ですがParameter ServerとThreadでは微妙にことなります。

  • ParameterServerクラス

  • LocalBrainクラス(Thread)

下から2行目が異なるだけでした。
plot_modelはコメントにもありますが、可視化するだけのもの。
_make_predict_function()とは、

普通モデルが実行されるとき、「GPU上でのモデルのビルドとコンパイル」そして「実行」という2段階の処理が行われる。
以下のページの議論によると、 実は普通にmodel.predict()をいきなり呼ぶと、この2段階の処理が毎回実行されて、めちゃくちゃ遅くなっている のだという。
そこで、 _make_predict_function()を先に実行しておくと、predictするときに「GPU上でのモデルのビルドとコンパイル」が行われず、単に「実行」だけできるようになり、処理が高速化される らしい。

  • 重みの学習

この部分はのちのち、sess.runでfeed_dictとして与えられる値のplace_holderですね。
そこから、モデルを使って、actor_critic(行動の価値が高くなる確率p)とvalue(状態価値v)が返される。

ここを理解するにはまず方策勾配定理を理解する必要がある。

get_collectionでscopeの変数を集める。
この場合はnameのTRAINABLE_VARIABLESだった。

gradientsでは、self.loss_totalをself.weights_paramsで微分した勾配が入る。

予測用NNの更新処理方法

Agentを介して、LocalBrainを使って並列処理してParameterServerを更新するようにしている。

  • ParameterServer(1 : n)
    • WorkerThread – Environment – Agent – LocalBrain
    • WorkerThread – Environment – Agent – LocalBrain
    • WorkerThread – Environment – Agent – LocalBrain
    • WorkerThread – Environment – Agent – LocalBrain
    • WorkerThread – Environment – Agent – LocalBrain

また手順は非常にわかりやすくまとめてくれています。

  1. スレッドはParameter Serverからネットワークの重みをコピーする
  2. スレッドのAgentは自分のネットワークにsを入力して、aを得る
  3. aを実行し、r(t)とs_を得る
  4. (s,a, r, s_)をスレッドのメモリに格納する
  5. 2〜4を繰り返す(各スレッドでTmaxステップ経過もしくは、終端に達するまで)
  6. 経験が十分に溜まったら、自分スレッドのメモリの内容を利用して、ネットワークの重みを更新させる方向gradを求める
  7. gradをParameter Serverに渡す
  8. Parameter Serverはgradの方向にParameter Serverのネットワークを更新する
  9. 1.へ戻る

ここは更新するところ。
apply_gradientsでlistを渡すことによって、普通にvariableを更新できる。
こういうやり方をちょっとはじめてみたんですが、update_global_weight_paramsにapply_gradientsを入れておいて、以下のように実行して更新する。

これを使っているメソッドはこちら。

pull_global_weight_paramspush_local_weight_paramsは更新するメソッドにしておく。

僕だけですが、tensorflowはどこでどうやって実行しているかわかりにくい。

ここにしっかりと準備と書いてくれているので非常にわかりやすいですが、
まずthreadを準備してからworker.run()以下でで実行します。

thread数は本記事ではN_WORKERS=8となっているので8つ。
今回実際にメソッドを呼び出しているのは、Environmentクラスになります。

whileで平均報酬が一定を超えるまでLocalBrainで学習をさせています。
終わるかどうかは、isLearnedで管理していて、親クラスのWorkerThreadでTrueになるまで学習を続けさせます。
gymからdone=Trueが返ってきた時はログを吐き出させて一旦ループから抜けますが、まだThreadとして学習は終わっていないので再度ループを回します。

doneになるかTmax回ごとにParameterServerの値を更新します。
そして、最新のものを使うようにします。

total_reward_vecは直近10回の報酬で、WorkerThreadの学習をやめるかどうかを判定する。

予測した報酬の算出

これはAgentが担います。
なので、LocalBrainはNNがあったり、行動の予測はしたりしますが、
実際にループを回すのはWorkerThreadやEnvironmentで、予測した行動から報酬を計算するのは、
Agentクラスで行います。

その他

これには関係ないけどGorila(General Reinforcement Learning Architecture)もマークしておいた方がよさそうだ。

# 関数名がアンダースコア2つから始まるものは「外部から参照されない関数」、「1つは基本的に参照しない関数」という意味
というpythonルールがある。

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です