MCMCで固定弾幕回避学習

by uwi forked from GAで固定弾幕回避学習 (diff: 360)
MCMC+きつめのアニーリングで、ストローク数が小さい領域での回避解を探す。
右側の欄に追尾弾の情報をいれてStart To Learn.

評価値 : ((最初に当たるまで生きていた時間) - 0.01*(ストローク数-1) - 0.001*(非0の個数)) / (シミュレート時間)
0が不移動, 1が右で、以降反時計回りに動きを割り当てている。
99.83点より上はノーミス、600フレーム分を学習。
 左上の表示は
現在のコードのストローク数 現在のコードの評価値
最適コードのストローク数 最適コードの評価値
ステップ数
♥0 | Line 385 | Modified 2009-11-05 05:55:05 | MIT License
play

ActionScript3 source code

/**
 * Copyright uwi ( http://wonderfl.net/user/uwi )
 * MIT License ( http://www.opensource.org/licenses/mit-license.php )
 * Downloaded from: http://wonderfl.net/c/r2if
 */

// forked from uwi's GAで固定弾幕回避学習
package {
    import flash.text.TextField;
    import flash.display.*;
    import flash.filters.*;
    import flash.geom.*;
    import flash.events.*;
    import flash.ui.*;
    import com.bit101.components.*;
    
    // MCMC+きつめのアニーリングで、ストローク数が小さい領域での回避解を探す。
    // 右側の欄に追尾弾の情報をいれてStart To Learn.
    // 
    // 評価値 : ((最初に当たるまで生きていた時間) - 0.01*(ストローク数-1) - 0.001*(非0の個数)) / (シミュレート時間)
    // 0が不移動, 1が右で、以降反時計回りに動きを割り当てている。
    // 99.83点より上はノーミス、600フレーム分を学習。
    
    // 左上の表示は
    // 現在のコードのストローク数 現在のコードの評価値
    // 最適コードのストローク数 最適コードの評価値
    // ステップ数

    [SWF(backgroundColor="#000000", frameRate="30")]
    public class MCMCAvoider extends Sprite {
        private var _bullets : Array;
        private var _myx : Point;
        private var _nhit : int;
        private const R_ME : Number = 5.0;
        private var _T : int = 600;
        private var _CR : Number = 0.01;
        
        private var _tf : TextField;
        private var _tfinput : TextField;
        private var _tfcr : TextField;
        private var _tft : TextField;
        private var _tfresult : TextField;
        private var _submit : PushButton;
        private var _stop : PushButton;
        
        private var _shotPattern : Array;
        private var _av : MCMCLearner;
        
        private var W : Number = 400;
        private var H : Number = 465;
        
        private var _state : int;
        
        public function MCMCAvoider() {
            Wonderfl.capture_delay(5);
            
            var tfinputhead : TextField = new TextField();
            setParams(tfinputhead, {
                text : "x y r v interval",
                textColor : 0xffffff,
                borderColor : 0xffffff,
                border : true,
                x : 370,
                y : 0,
                width : 90,
                height : 20
            });
            addChild(tfinputhead);
            
            _tfinput = new TextField();
            setParams(_tfinput, {
                type : "input",
                text : "0 0 20 10 6\n150 0 20 10 6\n400 0 20 10 6",
                textColor : 0xffffff,
                borderColor : 0xffffff,
                border : true,
                x : 370,
                y : 20,
                width : 90,
                height : 150
            });
            addChild(_tfinput);
            
            // 冷却速度
            var tfcrhead : TextField = new TextField();
            setParams(tfcrhead, {
                text : "chilling rate",
                textColor : 0xffffff,
                borderColor : 0xffffff,
                border : true,
                x : 370,
                y : 180,
                width : 60,
                height : 20
            });
            addChild(tfcrhead);
            
            _tfcr = new TextField();
            setParams(_tfcr, {
                type : "input",
                text : "0.0005",
                textColor : 0xffffff,
                borderColor : 0xffffff,
                border : true,
                x : 430,
                y : 180,
                width : 30,
                height : 20
            });
            addChild(_tfcr);
            
            // シミュレーション時間
            var tfthead : TextField = new TextField();
            setParams(tfthead, {
                text : "time",
                textColor : 0xffffff,
                borderColor : 0xffffff,
                border : true,
                x : 370,
                y : 200,
                width : 60,
                height : 20
            });
            addChild(tfthead);
            
            _tft = new TextField();
            setParams(_tft, {
                type : "input",
                text : "600",
                textColor : 0xffffff,
                borderColor : 0xffffff,
                border : true,
                x : 430,
                y : 200,
                width : 30,
                height : 20
            });
            addChild(_tft);
            
            
            _tfresult = new TextField();
            setParams(_tfresult, {
                text : "[result]",
                textColor : 0xffffff,
                borderColor : 0xffffff,
                border : true,
                x : 0,
                y : 280,
                width : 460,
                height : 180,
                wordWrap : true
            });
            addChild(_tfresult);
            
            _submit = new PushButton(this, 370, 230, "Start To Learn", onSubmit);
            _submit.width = 90;

            _stop = new PushButton(this, 370, 250, "Stop/Resume Learning", onStop);
            _stop.width = 90;

            // デバッグ用
            _tf = new TextField();
            setParams(_tf, {
                autoSize : "left",
                textColor : 0xffffff,
                borderColor : 0xffffff,
                border : true
            });
            addChild(_tf);
            
            _state = 0;
        }
        
        private static function setParams(t : Object, v : Object) : Object
        {
            for(var k : String in v){
                t[k] = v[k];
            }
            return t;
        }
        
        private function onSubmit(e : MouseEvent) : void
        {
            removeEventListener(Event.ENTER_FRAME, onLearnStep);
            
            _state = 1;

            _shotPattern = [];
            for each(var line : String in _tfinput.text.split(/[\r\n]/)){
                var seg : Array = line.split(' ');
                if(seg.length == 5){
                    if(Number(seg[2]) > 0 && Number(seg[3]) > 0 && int(seg[4]) > 0){
                        _shotPattern.push({
                            x : Number(seg[0]),
                            y : Number(seg[1]),
                            r : Number(seg[2]),
                            v : Number(seg[3]),
                            interval : int(seg[4])
                        });
                    }
                }
            }
            
            _T = int(_tft.text);
            var cr : Number = Number(_tfcr.text);
            _av = new MCMCLearner(_T, simulate, cr, _tf);
            addEventListener(Event.ENTER_FRAME, onLearnStep);
            _g = 0;
        }
        
        private function onStop(e : MouseEvent) : void
        {
            if(_state == 0)return;
            if(hasEventListener(Event.ENTER_FRAME)){
                removeEventListener(Event.ENTER_FRAME, onLearnStep);
            }else{
                addEventListener(Event.ENTER_FRAME, onLearnStep);
            }
        }
        
        private var _g : int;
        
        private function onLearnStep(e : Event) : void
        {
            for(var i : int = 0;i < 100;i++){
                _g++;
                _av.step();
            }
             
            _tf.text = "" + 
                _av.Cur.length + "\t" + _av.CurScore + "\n" +
                _av.Elite.length + "\t" + _av.EliteScore + "\n" + 
                "step : " + _g + "\n";
                 
            var el : Array = _av.Elite;
            var elstr : String = "";
            var p : int = 0;
            var c : String = "";
            for(i = 0;i < _T;i++){
                if(p < el.length && i == el[p].t){
                    c = el[p].op.toString();
                    p++;
                }
                elstr += c;
            }
            _tfresult.text = elstr;
        }
        
        private const ST : Array = [
            [0, 0],
            [6, 0], [4, -4], [0, -6], [-4, -4],
            [-6, 0], [-4, 4], [0, 6], [4, 4]
            ];
        
        private function init() : void
        {
            _myx = new Point(W / 2, H / 2);
            _bullets = [];
        }
        
        // 0 : no move
        // 1 : R から左回り
        private function simulate(code : Array) : Number
        {
            init();
            _nhit = 0;
            
            var t : int;
            var nOperate : int = 0;
            var nNon0 : int = 0;
            
            var p : int = -1;
            var nextt : int = 0;
            var curST : Array;
            for(t = 0;t < _T;t++){
                if(judge())break;
                
                for each(var ptn : Object in _shotPattern){
                    if(t % ptn.interval == 0){
                        addBullet(ptn.x, ptn.y, ptn.r, ptn.v);
                    }
                }
                
                if(t == nextt){
                    p++;
                    nextt = p + 1 < code.length ? code[p + 1].t : -1;
                    curST = ST[code[p].op];
                    if(p > 0 && code[p - 1].op != code[p].op){
                        nOperate++;
                    }
                    if(code[p].op != 0)nNon0++;
                }
                _myx.x += curST[0];
                _myx.y += curST[1];
                
                moveBullets();
            }
            
//            _tf.appendText("" + _nhit + "\n");
            return (t - nOperate * 0.01 - nNon0 * 0.001) / _T * 100;
        }
        
        private function moveBullets() : void
        {
            // 弾
            for each(var b : Bullet in _bullets){
                b.xx += b.vx;
                b.xy += b.vy;
            }
        }
        
        // 弾削除
        private function removeBullet(i : int) : void
        {
            if(i < _bullets.length - 1){
                _bullets[i] = _bullets.pop();
            }else{
                _bullets.pop();
            }
        }
        
        // 当たり判定
        private function judge() : Boolean
        {
            var ret : Boolean = false;
            for(var i : int = _bullets.length - 1;i >= 0;i--){
                var b : Bullet = _bullets[i];
                if(
                    (b.xx - _myx.x) * (b.xx - _myx.x) + 
                    (b.xy - _myx.y) * (b.xy - _myx.y)
                    < (b.r + R_ME) * (b.r + R_ME)){
                        _nhit++;
                        removeBullet(i);
                        ret = true;
                        continue;
                }
                if(b.xx < 0 || b.xx > W || b.xy < 0 || b.xy > H){
                    removeBullet(i);
                }
            }
            if(_myx.x < 0 || _myx.x > W || _myx.y < 0 || _myx.y > H){
                _nhit+=10;
                ret = true;
                _myx.x = W / 2;
                _myx.y = H / 2;
            }
            return ret;
        }
        
        // 弾追加
        private function addBullet(x : Number, y : Number, r : Number, v : Number) : void
        {
            var vr : Number = Math.sqrt((_myx.x - x) * (_myx.x - x) + (_myx.y - y) * (_myx.y - y));

            var b : Bullet = new Bullet();
            b.xx = x;
            b.xy = y;
            b.vx = v * (_myx.x - x) / vr;
            b.vy = v * (_myx.y - y) / vr;
            b.r = r;
            _bullets.push(b);
        }        
    }
}

class Bullet
{
    public var xx : Number;
    public var xy : Number;
    public var vx : Number;
    public var vy : Number;
    public var r : Number;
}

import flash.text.TextField;

class MCMCLearner
{
    private var _cur : Array;
    private var _curscore : Number;
    
    private var _elite : Array;
    private var _elitescore : Number;
    
    private var _T : int; // シミュレートする時間
    private var _cr : Number; // 冷却速度
    private var _eval : Function; // 評価値を計算する関数
    
    private var _deb : TextField;
    
    public function MCMCLearner(T : int, eval : Function, cr : Number, deb : TextField)
    {
        _deb = deb;
        _eval = eval;
        _cr = cr;
        _T = T;
        
        init();
    }
    
    public function init() : void
    {
        _cur = generate();
        _curscore = _eval(_cur);
        _elite = clone(_cur);
        _elitescore = _curscore;
        
        _nstep = 0;
    }
    
    // 生成
    private function generate() : Array
    {
        /*
        var code : Array = [];
        var p : Number = 0.1;
        var i : int;
        
        code.push({t : 0, op : int(Math.random() * 9)});
        for(i = 1;i < _T;i++){
            if(Math.random() < p){
                code.push({t : i, op : int(Math.random() * 9)});
            }
        }
        return code;
        */
        return [{t : 0, op : 0}];
    }
    
    private function clone(a : Array) : Array
    {
        // deepcopy
        var code : Array = [];
        for each(var o : Object in a){
            code.push({t : o.t, op : o.op});
        }
        return code;
    }
    
    private var _nstep : int;
    
    // 進化
    public function step() : void
    {
        var next : Array = makeNext();
        var score : Number = _eval(next);
        var p : Number = Math.exp((100/_curscore - 100/score) * (1.0 + _cr * _nstep));
        if(Math.random() < p){
            _cur = next;
            _curscore = score;
            
            if(score > _elitescore){
                _elite = clone(_cur);
                _elitescore = _curscore;
            }
        }
        _nstep++;
    }
    
    // 無性生殖
    private function makeNext() : Array
    {
        var i : int;
        
        var code : Array = clone(_cur);
        
        if(Math.random() < 0.5){
            // 挿入
//            var t : int = Math.random() * _T;
            var t : int = Math.random() * (_curscore / 100 * _T); // 引っかかったところ以前を重点的に
            for(i = 0;i < code.length && t >= code[i].t;i++){
                if(t == code[i].t){
                    code[i].op = int(Math.random() * 9);
                    break;
                }
            }
            if(i == code.length || t < code[i].t){
                if(i == code.length)i == -1;
                code.splice(i, 0, {t : t, op : int(Math.random() * 9)});
            }
        }else{
            // 削除
            if(code.length > 1){
                var ind : int = Math.random() * (code.length - 1) + 1;
                code.splice(ind, 1);
            }
        }
        
        return code;
    }
    
    public function get Cur() : Array { return _cur; }
    public function get CurScore() : Number { return _curscore; }
    public function get Elite() : Array { return _elite; }
    public function get EliteScore() : Number { return _elitescore; }
}