forked from: Q-learning test

by uwi forked from Q-learning test (diff: 299)
♥0 | Line 193 | Modified 2010-05-20 09:12:59 | MIT License
play

ActionScript3 source code

/**
 * Copyright uwi ( http://wonderfl.net/user/uwi )
 * MIT License ( http://www.opensource.org/licenses/mit-license.php )
 * Downloaded from: http://wonderfl.net/c/ky67
 */

// forked from uwi's Q-learning test
package { 
	import flash.utils.*;
	import flash.display.*;
	import flash.text.*;
	import flash.geom.*;
	
	/**
	 * Demo: scatters random grey rectangles ("obstacles") on a 465x465 bitmap,
	 * trains a QLearner to steer a path across it from left to right, and
	 * draws the resulting greedy path in red.
	 *
	 * The bitmap is sampled on a GRID x GRID lattice with CELL-pixel spacing.
	 * State index = column * GRID + row. Every step advances one column and
	 * moves the row by action - 1 (actions 0, 1, 2 = up, stay, down),
	 * clamped to the grid.
	 */
	public class FlashTest extends Sprite {
		/** Number of lattice rows and columns. */
		private static const GRID : int = 93;
		/** Pixel spacing between lattice points. */
		private static const CELL : int = 5;
		/** Number of actions; action a moves the row by a - 1. */
		private static const NUM_ACTIONS : int = 3;
		
		private var _bmd : BitmapData;
		
		public function FlashTest() {
			_bmd = new BitmapData(465, 465, false, 0x000000);
			addChild(new Bitmap(_bmd));
			
			var tf : TextField = new TextField();
			addChild(tf);
			tf.textColor = 0xffffff;
			tf.height = 465;
			
			var i : int;
			_bmd.lock();
			
			// Place 100 obstacle rectangles; the fixed seed makes the layout reproducible.
			var random : XorShift128 = new XorShift128(3);
			for(i = 0;i < 100;i++){
				var c : uint = (random.get32() % 17) * (random.get32() % 17); // 0 .. 256
				// Clamp to 255: for c == 0, 256 * 0x010101 = 0x1010100 overflows the
				// 24-bit RGB range and would paint near-black instead of white.
				var grey : int = int(Math.min(255, 256 - c));
				_bmd.fillRect( 
					new Rectangle(random.get32() % 465, random.get32() % 465, random.get32() % 50, random.get32() % 50), 
					grey * 0x010101
					);
			}
			
			var mid : int = int(GRID / 2);
			var ql : QLearner = new QLearner(
				GRID * GRID, // number of states
				NUM_ACTIONS, // number of actions
				select, // action-selection function
				update, // state-transition function
				reward, // reward function
				mid, // initial state (column 0, middle row)
				GRID, // episode length
				0.5, // learning rate
				0.9 // discount factor
				);
			
			var s : int = getTimer();
			for(i = 0;i < 5000;i++){
				if(i % 500 == 0)tf.appendText("" + ql.getMaxQ() + "\n");
				ql.learn();
			}
			var g : int = getTimer();
			tf.appendText("" + (g - s) + " ms");
			
			// Draw the greedy path, clamping rows the same way nextState() does.
			var w : Array = ql.getOptimizedWay();
			var row : int = mid;
			for(i = 0;i < GRID;i++){
				_bmd.setPixel(i * CELL, row * CELL, 0xff0000);
				row = clampRow(row + int(w[i]) - 1);
			}
			_bmd.unlock();
		}
		
		/** Clamps a row index into [0, GRID - 1]. */
		private static function clampRow(row : int) : int
		{
			return row < 0 ? 0 : (row > GRID - 1 ? GRID - 1 : row);
		}
		
		/**
		 * Successor of `state` under `action`: one column to the right, row
		 * shifted by action - 1 and clamped to the grid. Clamping keeps a move
		 * off the top/bottom edge from wrapping into a neighbouring
		 * row/column, which would break the one-column-per-step invariant.
		 */
		private function nextState(state : int, action : int) : int
		{
			var col : int = int(state / GRID);
			var row : int = clampRow(state % GRID + action - 1);
			return (col + 1) * GRID + row;
		}
		
		/**
		 * Chooses an action at random with probability proportional to
		 * 0.05 + Q[state][action]. With the non-negative Q values produced by
		 * the 0/1 reward below, every action keeps a weight of at least 0.05.
		 */
		private function select(state : int, Q : Vector.<Vector.<Number>>) : int
		{
			var ps : Array = [];
			var psum : Number = 0;
			for(var i : int = 0;i < NUM_ACTIONS;i++){
				var p : Number = 0.05 + Q[state][i];
				ps.push(p);
				psum += p;
			}
			var pp : Number = Math.random() * psum;
			for(i = 0;i < NUM_ACTIONS;i++){
				if(pp < ps[i]){
					return i;
				}
				pp -= ps[i];
			}
			// Rounding in the running subtraction can leave pp marginally at or
			// above the last weight; return the last action instead of -1.
			return NUM_ACTIONS - 1;
		}
		
		/** State-transition function handed to QLearner; delegates to nextState(). */
		private function update(state : int, action : int) : int
		{
			return nextState(state, action);
		}
		
		/**
		 * Returns 1 if the lattice point reached by taking `action` in `state`
		 * is black (free), and 0 if its pixel is non-zero (an obstacle).
		 */
		private function reward(state : int, action : int) : Number
		{
			var next : int = nextState(state, action);
			var col : int = int(next / GRID);
			var row : int = next % GRID;
			return _bmd.getPixel(col * CELL, row * CELL) > 0 ? 0 : 1;
		}
		
	}
}

import flash.text.TextField;

/**
 * Tabular Q-learning over integer-indexed states and actions.
 *
 * Action selection, state transitions and rewards are supplied by the caller
 * as functions. Q values live in an nState x nAction table initialised to 0.
 */
class QLearner
{
	private var _select : Function;
	private var _update : Function;
	private var _reward : Function;
	
	private var _Q : Vector.<Vector.<Number>>;
	private var _nState : int;
	private var _nAction : int;
	
	private var _alpha : Number;
	private var _gamma : Number;
	
	private var _iniState : int;
	private var _tlim : int;
	
	/**
	* @param nState number of states (state indices 0 .. nState - 1)
	* @param nAction number of actions (action indices 0 .. nAction - 1)
	* @param select action-selection function (state:int, Q:Vector.<Vector.<Number>>) : int;
	*        must return an action index in [0, nAction)
	* @param update state-transition function (state:int, action:int) : int
	* @param reward reward function (state:int, action:int) : Number
	* @param iniState state every episode starts from
	* @param tlim episode length; learn() performs tlim - 1 updates
	* @param alpha base learning rate; step t of an episode uses alpha / (t + 10)
	* @param gamma discount factor
	*/
	public function QLearner(nState : int, nAction : int, select : Function, update : Function, reward : Function, iniState : int, tlim : int, alpha : Number = 0.1, gamma : Number = 0.1)
	{
		_nState = nState;
		_nAction = nAction;
		_select = select;
		_update = update;
		_reward = reward;
		_iniState = iniState;
		_tlim = tlim;
		_alpha = alpha;
		_gamma = gamma;
		_Q = new Vector.<Vector.<Number>>(_nState);
		for(var i : uint = 0;i < _nState;i++){
			_Q[i] = new Vector.<Number>(_nAction);
			for(var j : uint = 0;j < _nAction;j++){
				_Q[i][j] = 0.0;
			}
		}
	}
	
	/**
	 * Index of the action with the largest Q value in `state` (the first one
	 * on ties), or -1 when there are no actions. Starting from the first
	 * action rather than a sentinel value makes this correct for arbitrarily
	 * negative Q values.
	 */
	private function bestAction(state : int) : int
	{
		var qs : Vector.<Number> = _Q[state];
		var best : int = _nAction > 0 ? 0 : -1;
		for(var aa : int = 1;aa < _nAction;aa++){
			if(qs[aa] > qs[best])best = aa;
		}
		return best;
	}
	
	/**
	 * Runs one episode of tlim - 1 steps from the initial state, applying
	 * Q(s,a) <- (1 - alpha_t) Q(s,a) + alpha_t (r + gamma * max_a' Q(s',a')).
	 */
	public function learn() : void
	{
		var state : int = _iniState;
		for(var t : int = 0;t < _tlim - 1;t++){
			var action : int = _select(state, _Q); // choose an action
			var next : int = _update(state, action);
			var reward : Number = _reward(state, action); // compute the reward
			// The step size decays with the time index inside the episode.
			var alpha : Number = _alpha / (t + 10);
			
			// True maximum over the next state's actions (may be negative).
			var nb : int = bestAction(next);
			var maxq : Number = nb >= 0 ? _Q[next][nb] : 0;
			_Q[state][action] = (1 - alpha) * _Q[state][action] + alpha * (reward + _gamma * maxq);
			
			state = next;
		}
	}
	
	/**
	 * Follows the greedy policy for tlim steps from the initial state.
	 * @return the tlim chosen action indices, in order
	 */
	public function getOptimizedWay() : Array
	{
		var ret : Array = new Array(_tlim);
		var state : int = _iniState;
		for(var t : int = 0;t < _tlim;t++){
			var a : int = bestAction(state);
			ret[t] = a;
			state = _update(state, a);
		}
		return ret;
	}
	
	/**
	 * Follows the greedy policy for tlim steps from the initial state.
	 * @return the sum over steps t of gamma^t * max_a Q(s_t, a)
	 */
	public function getMaxQ() : Number
	{
		var ret : Number = 0;
		var state : int = _iniState;
		var gt : Number = 1;
		for(var t : int = 0;t < _tlim;t++){
			var a : int = bestAction(state);
			ret += gt * (a >= 0 ? _Q[state][a] : 0);
			gt *= _gamma;
			state = _update(state, a);
		}
		return ret;
	}
}

/**
 * Marsaglia's xorshift128 generator: four 32-bit state words and one
 * 32-bit output per call to get32().
 */
class XorShift128
{
    private var a:uint;
    private var b:uint;
    private var c:uint;
    private var d:uint;
    
    /** @param seed any 32-bit value; expanded into the four state words by setSeed(). */
    public function XorShift128(seed : uint = 0)
    {
        setSeed(seed);
    }
    
    /**
     * Returns x * y modulo 2^32.
     *
     * AS3 evaluates integer multiplication as a Number (double). A product
     * above 2^53 loses its low bits before it is truncated back to uint, so
     * the multiply is done on 16-bit halves to keep every intermediate exact.
     */
    public static function mul32(x:uint, y:uint):uint
    {
        var lo:Number = (x & 0xffff) * y;          // < 2^48: exact
        var hi:uint = ((x >>> 16) * y) & 0xffff;   // only these bits survive the << 16 mod 2^32
        return uint(lo + (hi << 16));
    }
    
    /**
     * Seeds the generator. Each state word is derived from the previous value
     * with s = 1812433253 * (s ^ (s >>> 30)) + i (mod 2^32), i = 0..3.
     * This is the recurrence used by MT19937's init_genrand.
     */
    public function setSeed( seed:uint ):void
    {
        seed = mul32(1812433253, seed ^ (seed >>> 30)) + 0; a = seed;
        seed = mul32(1812433253, seed ^ (seed >>> 30)) + 1; b = seed;
        seed = mul32(1812433253, seed ^ (seed >>> 30)) + 2; c = seed;
        seed = mul32(1812433253, seed ^ (seed >>> 30)) + 3; d = seed;
    }

    /** Advances the state and returns the next 32-bit pseudo-random integer. */
    public function get32():uint
    {
        var t:uint = a^(a<<11);
        a=b; b=c; c=d;
        return( d=(d^(d>>>19))^(t^(t>>>8)) );
    }
  
}