Q-learning test

by uwi
できるだけ明るいところを渡るようにと言っておいたはずなんですが・・
♥0 | Line 134 | Modified 2009-10-04 15:30:15 | MIT License
play

ActionScript3 source code

/**
 * Copyright uwi ( http://wonderfl.net/user/uwi )
 * MIT License ( http://www.opensource.org/licenses/mit-license.php )
 * Downloaded from: http://wonderfl.net/c/mNSu
 */

package {
    import flash.utils.*;
    import flash.display.*;
    import flash.text.*;
    import flash.geom.*;
    
    /**
     * Q-learning demo: an agent crosses a random grey-scale image from left to
     * right and is rewarded for passing through bright pixels.
     *
     * The SIZE x SIZE bitmap is treated as a GRID x GRID lattice of CELL-pixel
     * cells. At time step t the agent sits in lattice column t; each step it
     * moves up one row, stays, or moves down one row (actions 0, 1, 2). The
     * learner's state is encoded as t * GRID + row, and the reward for a state
     * is the low byte of the pixel at the top-left corner of its cell. After
     * training, the greedy path is drawn in red.
     *
     * (Original author's note, translated: "I told it to cross through the
     * brightest places it could, but...")
     */
    public class FlashTest extends Sprite {
        /** Bitmap width and height in pixels. */
        private static const SIZE : int = 465;
        /** Edge length of one lattice cell in pixels. */
        private static const CELL : int = 5;
        /** Lattice rows/columns; also the episode length in time steps. */
        private static const GRID : int = int(SIZE / CELL);
        /** Number of training episodes. */
        private static const EPISODES : int = 10000;
        /** Print the greedy path's discounted value every this many episodes. */
        private static const REPORT_EVERY : int = 2000;
        
        public function FlashTest() {
            var bmd : BitmapData = new BitmapData(SIZE, SIZE, false, 0x000000);
            addChild(new Bitmap(bmd));
            
            var tf : TextField = new TextField();
            addChild(tf);
            tf.textColor = 0xffffff;
            tf.height = SIZE;
            
            // Background: 100 random grey rectangles of up to 50x50 px on black.
            // (Alternative background: bmd.perlinNoise(SIZE, SIZE, 6, 0, true, false, 7, true);)
            var i : int;
            for (i = 0; i < 100; i++) {
                bmd.fillRect(
                    new Rectangle(Math.random() * SIZE, Math.random() * SIZE, Math.random() * 50, Math.random() * 50),
                    int(Math.random() * 256) * (65536 + 256 + 1)  // same byte in R, G and B => grey
                    );
            }
            
            var start : int = int(GRID / 2);
            var ql : QLearner = new QLearner(GRID * GRID, 3,
                // Transition: advance one column and shift the row by (action - 1).
                // The row is clamped to the image so every state keeps a unique
                // (t, row) encoding: an unclamped row of -1 or GRID would alias a
                // state of the neighbouring column (sharing its Q-values) and could
                // index past the end of the GRID * GRID Q table.
                function(state : int, action : int) : int {
                    var t : int = int(state / GRID);
                    var nextRow : int = clampRow(state - t * GRID + action - 1);
                    return (t + 1) * GRID + nextRow;
                },
                // Reward: low byte (blue channel) of the cell's top-left pixel,
                // which equals its brightness for the grey rectangles drawn above.
                function(state : int, t : int) : Number {
                    var y : int = (state - t * GRID) * CELL;
                    if (y < 0 || y >= SIZE) return 0;
                    return bmd.getPixel(t * CELL, y) & 255;
                },
                start,
                GRID,
                0.8,
                0.8
                );
            
            ql.init();
            var s : int = getTimer();
            for (i = 0; i < EPISODES; i++) {
                if (i % REPORT_EVERY == 0) tf.appendText("" + ql.getMaxQ() + "\n");
                ql.learn();
            }
            var g : int = getTimer();
            tf.appendText("" + (g - s) + " ms");
            
            // Replay the greedy actions with the same clamped row update as the
            // transition function and mark each visited cell in red.
            var way : Array = ql.getOptimizedWay();
            var row : int = start;
            for (i = 0; i < GRID; i++) {
                bmd.setPixel(i * CELL, row * CELL, 0xff0000);
                row = clampRow(row + int(way[i]) - 1);
            }
        }
        
        /** Clamps a lattice row index into [0, GRID - 1]. */
        private static function clampRow(row : int) : int {
            if (row < 0) return 0;
            if (row >= GRID) return GRID - 1;
            return row;
        }
    }
}

/**
 * Tabular Q-learner over a finite-horizon episode.
 *
 * The Q table is a flat Array of nState * nAction entries indexed as
 * state * nAction + action. Each call to learn() runs one episode of
 * (tlim - 1) uniformly random actions starting at iniState and applies
 *   Q(s, a) <- (1 - alpha_t) * Q(s, a) + alpha_t * (r(s, t) + gamma * max_a' Q(s', a'))
 * with a per-time-step learning rate alpha_t = alpha / (t + 10).
 */
class QLearner
{
    /** reward(state : int, t : int) : Number -- reward for being in `state` at time step `t`. */
    private var _reward : Function;
    /** update(state : int, action : int) : int -- returns the state reached by taking `action`. */
    private var _update : Function;
    
    /** Flat Q table; null until init() (or the first learn()). */
    private var _Q : Array;
    private var _nState : int;
    private var _nAction : int;
    
    private var _alpha : Number;
    private var _gamma : Number;
    
    private var _iniState : int;
    private var _tlim : int;
    
    /**
     * @param nState   number of states (Q table rows)
     * @param nAction  number of actions per state (Q table columns)
     * @param update   transition function (state, action) -> next state
     * @param reward   reward function (state, t) -> Number
     * @param iniState state every episode starts from
     * @param tlim     episode length in time steps
     * @param alpha    base learning rate; step t uses alpha / (t + 10)
     * @param gamma    discount factor
     */
    public function QLearner(nState : int, nAction : int, update : Function, reward : Function, iniState : int, tlim : int, alpha : Number = 0.1, gamma : Number = 0.1)
    {
        _nState = nState;
        _nAction = nAction;
        _reward = reward;
        _update = update;
        _alpha = alpha;
        _gamma = gamma;
        _iniState = iniState;
        _tlim = tlim;
    }
    
    /** Allocates the Q table and resets every entry to 0. */
    public function init() : void
    {
        _Q = new Array(_nState * _nAction);
        for (var i : int = 0; i < _Q.length; i++) _Q[i] = 0;
    }
    
    /**
     * Runs one exploration episode with uniformly random actions and updates
     * the Q table. Calls init() first if the table has not been created yet.
     * @return the states visited after each step (iniState itself excluded)
     */
    public function learn() : Array
    {
        if (_Q == null) init();
        
        var ret : Array = [];
        var state : int = _iniState;
        for (var t : int = 0; t < _tlim - 1; t++) {
            var a : int = Math.random() * _nAction;
            var r : Number = _reward(state, t);
            var alpha : Number = _alpha / (t + 10);
            
            var nex : int = _update(state, a);
            // True maximum over the next state's actions. Seeding the search
            // with 0 would clamp negative Q-values (negative rewards) up to 0.
            var maxq : Number = _qAt(nex, _greedyAction(nex));
            _Q[state * _nAction + a] = (1 - alpha) * _qAt(state, a) + alpha * (r + _gamma * maxq);
            
            state = nex;
            ret.push(state);
        }
        return ret;
    }
    
    /**
     * Follows the greedy policy from iniState for tlim steps.
     * @return the action chosen at each of the tlim time steps
     */
    public function getOptimizedWay() : Array
    {
        var ret : Array = new Array(_tlim);
        var state : int = _iniState;
        for (var t : int = 0; t < _tlim; t++) {
            var best : int = _greedyAction(state);
            ret[t] = best;
            state = _update(state, best);
        }
        return ret;
    }
    
    /**
     * Discounted sum, over the greedy path from iniState, of
     * gamma^t * max_a Q(state_t, a) for t in [0, tlim).
     */
    public function getMaxQ() : Number
    {
        var ret : Number = 0;
        var state : int = _iniState;
        for (var t : int = 0; t < _tlim; t++) {
            var best : int = _greedyAction(state);
            ret += Math.pow(_gamma, t) * _qAt(state, best);
            state = _update(state, best);
        }
        return ret;
    }
    
    /**
     * Index of the action with the largest Q-value in `state`; ties go to the
     * lowest index. Always a valid action in [0, nAction) when nAction > 0,
     * regardless of the sign of the Q-values.
     */
    private function _greedyAction(state : int) : int
    {
        var best : int = 0;
        var bestQ : Number = _qAt(state, 0);
        for (var aa : int = 1; aa < _nAction; aa++) {
            var q : Number = _qAt(state, aa);
            if (q > bestQ) {
                bestQ = q;
                best = aa;
            }
        }
        return best;
    }
    
    /**
     * Q(state, action). Entries that are missing (table not created, or an
     * index outside the table, which reads back as undefined) or NaN are
     * treated as the initial value 0.
     */
    private function _qAt(state : int, action : int) : Number
    {
        if (_Q == null) return 0;
        var v : Number = _Q[state * _nAction + action];
        return isNaN(v) ? 0 : v;
    }
}

Forked