// Forked from: Q-learning test
// Forked from Q-learning test (diff: 299)
// ActionScript 3 source code
/**
* Copyright uwi ( http://wonderfl.net/user/uwi )
* MIT License ( http://www.opensource.org/licenses/mit-license.php )
* Downloaded from: http://wonderfl.net/c/ky67
*/
// forked from uwi's Q-learning test
package {
import flash.utils.*;
import flash.display.*;
import flash.text.*;
import flash.geom.*;
/**
 * Q-learning demo on a 93 x 93 grid drawn into a 465 x 465 bitmap.
 *
 * A state is encoded as t * GRID + y, where t is the time step (mapped to
 * the bitmap's x axis) and y is the row (mapped to the bitmap's y axis).
 * Each step advances t by one and moves the row by -1, 0 or +1. Landing on
 * a black pixel is rewarded with 1, landing on a painted obstacle with 0.
 * After training, the greedy path is plotted in red.
 */
public class FlashTest extends Sprite {
    /** Number of time steps and number of rows of the state grid. */
    private static const GRID : int = 93;
    /** Edge length, in pixels, of one grid cell. */
    private static const CELL : int = 5;
    /** Edge length, in pixels, of the bitmap (GRID * CELL). */
    private static const SIZE : int = 465;
    /** Number of actions: 0 = row - 1, 1 = same row, 2 = row + 1. */
    private static const N_ACTIONS : int = 3;

    private var _bmd : BitmapData;

    /**
     * Builds the obstacle field, trains the learner for 5000 episodes while
     * logging progress to a text field, and draws the resulting greedy path.
     */
    public function FlashTest() {
        _bmd = new BitmapData(SIZE, SIZE, false, 0x000000);
        addChild(new Bitmap(_bmd));
        var tf : TextField = new TextField();
        addChild(tf);
        tf.textColor = 0xffffff;
        tf.height = SIZE;

        // Place the obstacles: 100 grey rectangles from a fixed-seed generator.
        // The order of the get32() calls is significant for the layout.
        var i : int;
        _bmd.lock();
        var random : XorShift128 = new XorShift128(3);
        for (i = 0; i < 100; i++) {
            var c : uint = (random.get32() % 17) * (random.get32() % 17);
            // c ranges over 0..256, so 256 - c can reach 256; clamp it to the
            // 8-bit channel range before replicating it into R, G and B.
            var level : int = Math.min(255, 256 - c);
            _bmd.fillRect(
                new Rectangle(random.get32() % SIZE, random.get32() % SIZE, random.get32() % 50, random.get32() % 50),
                level * 0x010101
            );
        }

        var ql : QLearner = new QLearner(
            GRID * GRID,    // number of states
            N_ACTIONS,      // number of actions
            select,         // action selection function
            update,         // state transition function
            reward,         // reward function
            int(GRID / 2),  // initial state: t = 0, middle row
            GRID,           // steps per episode
            0.5,            // learning rate
            0.9             // discount rate
        );

        // Train, printing the greedy-path value every 500 episodes.
        var s : int = getTimer();
        for (i = 0; i < 5000; i++) {
            if (i % 500 == 0) tf.appendText("" + ql.getMaxQ() + "\n");
            ql.learn();
        }
        var g : int = getTimer();
        tf.appendText("" + (g - s) + " ms");

        // Plot the greedy path, moving the row exactly as nextState() does.
        var w : Array = ql.getOptimizedWay();
        var row : int = int(GRID / 2);
        for (i = 0; i < GRID; i++) {
            _bmd.setPixel(i * CELL, row * CELL, 0xff0000);
            row = clampRow(row + (w[i] - 1));
        }
        _bmd.unlock();
    }

    /**
     * Picks an action at random with probability proportional to
     * 0.05 + Q[state][action].
     *
     * @param state current state index
     * @param Q     Q-value table indexed as Q[state][action]
     * @return the chosen action index; never -1 for a non-empty action set
     */
    private function select(state : int, Q : Vector.<Vector.<Number>>) : int
    {
        // Alternative (uniform choice): return uint(Math.random() * 3);
        var q : Vector.<Number> = Q[state];
        var n : int = q.length;
        var weights : Vector.<Number> = new Vector.<Number>(n);
        var total : Number = 0;
        var i : int;
        for (i = 0; i < n; i++) {
            weights[i] = 0.05 + q[i];
            total += weights[i];
        }
        var pick : Number = Math.random() * total;
        for (i = 0; i < n - 1; i++) {
            if (pick < weights[i]) {
                return i;
            }
            pick -= weights[i];
        }
        // Whatever is left belongs to the last action. Returning it directly
        // also covers the case where rounding in the subtractions above would
        // otherwise leave no action selected.
        return n - 1;
    }

    /**
     * State transition passed to the learner.
     *
     * @param state  current state index
     * @param action action index (0, 1 or 2)
     * @return the index of the state reached
     */
    private function update(state : int, action : int) : int
    {
        return nextState(state, action);
    }

    /**
     * Reward passed to the learner: 1 when the cell reached by the action is
     * black in the bitmap, 0 otherwise.
     *
     * @param state  current state index
     * @param action action index (0, 1 or 2)
     * @return 0 or 1
     */
    private function reward(state : int, action : int) : Number
    {
        var next : int = nextState(state, action);
        var t : int = int(next / GRID);
        var y : int = next % GRID;
        // Alternative (graded reward):
        // var px : Number = (_bmd.getPixel(t * CELL, y * CELL) & 255) / 255;
        // return Math.sqrt(px);
        return _bmd.getPixel(t * CELL, y * CELL) > 0 ? 0 : 1;
    }

    /**
     * Shared transition used by update() and reward(): advances the time
     * step by one and moves the row by (action - 1), keeping the row inside
     * the grid so a move at the top or bottom edge cannot spill into a
     * neighbouring time step.
     */
    private function nextState(state : int, action : int) : int
    {
        var t : int = int(state / GRID);
        var y : int = clampRow(state % GRID + (action - 1));
        return (t + 1) * GRID + y;
    }

    /** Limits a row index to the range 0 .. GRID - 1. */
    private function clampRow(y : int) : int
    {
        if (y < 0) return 0;
        if (y > GRID - 1) return GRID - 1;
        return y;
    }
}
}
import flash.text.TextField;
/**
 * Tabular Q-learning over integer states and actions.
 *
 * The environment is supplied as three callbacks (action selection, state
 * transition and reward). Q-values are stored in a table indexed as
 * _Q[state][action] and start at 0.
 */
class QLearner
{
    private var _select : Function;
    private var _update : Function;
    private var _reward : Function;
    private var _Q : Vector.<Vector.<Number>>;
    private var _nState : int;
    private var _nAction : int;
    private var _alpha : Number;
    private var _gamma : Number;
    private var _iniState : int;
    private var _tlim : int;

    /**
     * @param nState   number of states
     * @param nAction  number of actions
     * @param select   selection function (state:int, Q:Vector.<Vector.<Number>>) : int
     * @param update   transition function (state:int, action:int) : int
     * @param reward   reward function (state:int, action:int) : Number
     * @param iniState initial state of every episode
     * @param tlim     episode length in steps
     * @param alpha    learning rate
     * @param gamma    discount rate
     */
    public function QLearner(nState : int, nAction : int, select : Function, update : Function, reward : Function, iniState : int, tlim : int, alpha : Number = 0.1, gamma : Number = 0.1)
    {
        _nState = nState;
        _nAction = nAction;
        _select = select;
        _update = update;
        _reward = reward;
        _iniState = iniState;
        _tlim = tlim;
        _alpha = alpha;
        _gamma = gamma;
        _Q = new Vector.<Vector.<Number>>(_nState);
        for (var i : int = 0; i < _nState; i++) {
            _Q[i] = new Vector.<Number>(_nAction);
            for (var j : int = 0; j < _nAction; j++) {
                _Q[i][j] = 0.0;
            }
        }
    }

    /**
     * Runs one episode of tlim - 1 steps from the initial state, updating
     * the Q-table after every step.
     */
    public function learn() : void
    {
        var state : int = _iniState;
        for (var t : int = 0; t < _tlim - 1; t++) {
            var action : int = _select(state, _Q);         // choose an action
            var next : int = _update(state, action);
            var reward : Number = _reward(state, action);  // compute the reward
            // The step size shrinks with the step index inside the episode.
            var alpha : Number = _alpha / (t + 10);
            // Q-value update. The bootstrap term is the true maximum over
            // the next state's actions, so it may be negative.
            var maxq : Number = _Q[next][greedyAction(next)];
            _Q[state][action] = (1 - alpha) * _Q[state][action] + alpha * (reward + _gamma * maxq);
            state = next;
        }
    }

    /**
     * Follows the greedy policy from the initial state for tlim steps.
     *
     * @return an Array of length tlim holding the action taken at each step
     */
    public function getOptimizedWay() : Array
    {
        var ret : Array = new Array(_tlim);
        var state : int = _iniState;
        for (var t : int = 0; t < _tlim; t++) {
            var best : int = greedyAction(state);
            state = _update(state, best);
            ret[t] = best;
        }
        return ret;
    }

    /**
     * Follows the greedy policy from the initial state for tlim steps and
     * sums the maximum Q-value of each visited state, weighted by gamma
     * raised to the step index.
     *
     * @return the discounted sum
     */
    public function getMaxQ() : Number
    {
        var ret : Number = 0;
        var state : int = _iniState;
        var gt : Number = 1;
        for (var t : int = 0; t < _tlim; t++) {
            var best : int = greedyAction(state);
            var maxq : Number = _Q[state][best];
            state = _update(state, best);
            ret += gt * maxq;
            gt *= _gamma;
        }
        return ret;
    }

    /**
     * Returns the action with the largest Q-value in the given state; ties
     * go to the lowest action index. The search starts from action 0 rather
     * than from a sentinel value, so a valid action is returned even when
     * every Q-value is negative.
     */
    private function greedyAction(state : int) : int
    {
        var row : Vector.<Number> = _Q[state];
        var best : int = 0;
        for (var aa : int = 1; aa < _nAction; aa++) {
            if (row[best] < row[aa]) {
                best = aa;
            }
        }
        return best;
    }
}
/**
 * Xorshift pseudo-random generator holding its state in four 32-bit words.
 */
class XorShift128
{
    private var a:uint;
    private var b:uint;
    private var c:uint;
    private var d:uint;

    /**
     * @param seed value used to derive the four state words
     */
    public function XorShift128(seed : uint = 0)
    {
        setSeed(seed);
    }

    /**
     * Re-seeds the generator. Each state word is derived from the previous
     * one, starting from the seed, with offsets 0, 1, 2 and 3.
     */
    public function setSeed( seed:uint ):void
    {
        var word:uint = scramble(seed, 0);
        a = word;
        word = scramble(word, 1);
        b = word;
        word = scramble(word, 2);
        c = word;
        word = scramble(word, 3);
        d = word;
    }

    /**
     * Produces the next 32-bit unsigned value and advances the state.
     */
    public function get32():uint
    {
        var head:uint = a ^ (a << 11);
        var tail:uint = d ^ (d >>> 19);
        // Rotate the state words down by one slot.
        a = b;
        b = c;
        c = d;
        d = tail ^ (head ^ (head >>> 8));
        return d;
    }

    /**
     * One seeding step. The expression is kept in this exact form: the
     * multiplication is evaluated as written and only the final result is
     * converted to uint.
     */
    private static function scramble(prev:uint, offset:uint):uint
    {
        return 1812433253 * (prev ^ (prev >>> 30)) + offset;
    }
}