/**
 * Copyright jetbead ( http://wonderfl.net/user/jetbead )
 * MIT License ( http://www.opensource.org/licenses/mit-license.php )
 * Downloaded from: http://wonderfl.net/c/uD0w
 */

package 
{
    import flash.display.*;
    import flash.events.*;
    import flash.text.TextField;

    /**
     * Perceptron demo entry point.
     *
     * The user clicks points onto a 2D coordinate plane; labels alternate between
     * positive (orange) and negative (blue). Every few frames one stored point is
     * picked at random, the perceptron is trained on it, and the plane is repainted
     * with the current decision regions.
     */
    public class Main extends Sprite 
    {
        // Half-width of the plotted range: the plane spans [-coord_scale, coord_scale] on both axes.
        private const coord_scale:Number = 10;
        // Width and height of the plotted plane, in pixels.
        private const coord_size:int = 200;
        // Frames to wait between training steps so each update stays visible.
        private const frames_between_steps:int = 2;
        
        private var coord:Coord2DSprite;
        private var perceptron:Perceptron;
        // Frames elapsed since the last training step.
        private var cnt:int = 0;
        
        public function Main():void 
        {
            if (stage) init();
            else addEventListener(Event.ADDED_TO_STAGE, init);
        }
        
        /**
         * Sets up the stage, the instruction text, the coordinate plane and the
         * perceptron, then starts the per-frame training loop.
         */
        private function init(e:Event = null):void 
        {
            removeEventListener(Event.ADDED_TO_STAGE, init);
            
            stage.align = StageAlign.TOP_LEFT;
            stage.scaleMode = StageScaleMode.NO_SCALE;
            stage.frameRate = 30;
            
            var tf:TextField = new TextField();
            tf.x = 0;
            tf.y = 0;
            tf.width = 465;
            tf.height = 50;
            tf.selectable = false;
            tf.text = "座標をクリックすると、正例(橙点)と負例(青点)を交互に入力できます。\n入力すると随時ランダムに点を選んで学習します。\n線形分離可能ならば、いつかは落ち着きます。たぶん。";
            addChild(tf);
            
            // Centre the plane around (233, 233).
            coord = new Coord2DSprite(coord_size, coord_size, coord_scale, coord_scale);
            coord.x = 233 - coord_size / 2;
            coord.y = 233 - coord_size / 2;
            addChild(coord);
            
            perceptron = new Perceptron(0.0);
            
            addEventListener(Event.ENTER_FRAME, draw);
        }
        
        /**
         * ENTER_FRAME handler. Once at least one point exists, it performs one
         * training step every (frames_between_steps + 1) frames, then repaints.
         */
        private function draw(e:Event):void {
            if (coord.pointType.length == 0) return;
            if (cnt < frames_between_steps) {
                cnt++;
                return;
            }
            cnt = 0;
            
            // Pick a random stored point, highlight it, and train on it.
            var idx:int = int(Math.random() * coord.pointType.length);
            coord.selectPoint(coord.pointX[idx], coord.pointY[idx]);
            perceptron.train(coord.pointType[idx], { "x":coord.pointX[idx], "y":coord.pointY[idx] } );
            
            // Show the current separating line as coloured regions.
            draw_w();
        }
        
        /**
         * Repaints every background pixel of the plane: orange where the
         * perceptron's score is positive, blue otherwise.
         */
        private function draw_w():void {
            var bmp:BitmapData = coord.bg_data;
            // Defer Bitmap updates until all pixels are written.
            bmp.lock();
            for (var i:int = 0; i < coord_size; i++) {
                // Plane x-coordinate depends only on the column.
                var nx:Number = ( i * 2 * coord_scale ) / coord_size - coord_scale;
                for (var j:int = 0; j < coord_size; j++) {
                    // Pixel rows grow downward; the plane's y-axis grows upward.
                    var ny:Number = ( (coord_size - j) * 2 * coord_scale ) / coord_size - coord_scale;
                    var ret:Number = perceptron.predict( { "x":nx, "y":ny, "bias":perceptron.bias } );
                    bmp.setPixel(i, j, ret > 0 ? 0xffa500 : 0x4169e1);
                }
            }
            bmp.unlock();
        }
    }
    
}

// Sprite for the coordinate plane: handles click input, point plotting, and coordinate conversion
import flash.display.Sprite;
import flash.display.Bitmap;
import flash.display.BitmapData;
import flash.events.MouseEvent;

/**
 * Interactive 2D coordinate plane of W x H pixels, mapping to the range
 * [-scaleX_, scaleX_] x [-scaleY_, scaleY_] with the y-axis pointing up.
 * Each click stores a labelled point; labels alternate between 1 and -1.
 */
class Coord2DSprite extends Sprite {
    // Label assigned to the next clicked point; toggles between 1 and -1 after each click.
    public var type:int = 1;
    // Labels and plane coordinates of clicked points (parallel arrays, same index = same point).
    public var pointType:Array;
    public var pointX:Array;
    public var pointY:Array;
    // Display layers.
    public var bg:Bitmap; // background; recoloured by the owner to show the decision regions
    public var bg_data:BitmapData;
    public var psp:Sprite; // axes and plotted points
    public var circle:Sprite; // marker for the point currently being trained on
    // Pixel size and plane range, used for pixel <-> plane conversion.
    private var W:int;
    private var H:int;
    private var minX:Number;
    private var maxX:Number;
    private var minY:Number;
    private var maxY:Number;
    
    /**
     * @param W_      plane width in pixels
     * @param H_      plane height in pixels
     * @param scaleX_ half-width of the x range (plane spans -scaleX_..scaleX_)
     * @param scaleY_ half-height of the y range (plane spans -scaleY_..scaleY_)
     */
    public function Coord2DSprite(W_:int, H_:int, scaleX_:Number = 1, scaleY_:Number = 1) {
        W = W_;
        H = H_;
        minX = -scaleX_;
        maxX = scaleX_;
        minY = -scaleY_;
        maxY = scaleY_;
        
        pointType = new Array();
        pointX = new Array();
        pointY = new Array();
        
        // Opaque light-grey background.
        bg_data = new BitmapData(W, H, true, 0xFFCCCCCC);
        bg = new Bitmap(bg_data);
        this.addChild(bg);
        psp = new Sprite();
        this.addChild(psp);
        circle = new Sprite();
        this.addChild(circle);
        
        // Axes through the centre.
        psp.graphics.lineStyle(1, 0x000000);
        psp.graphics.moveTo(0, H_ / 2);
        psp.graphics.lineTo(W_, H_ / 2);
        psp.graphics.moveTo(W_ / 2, 0);
        psp.graphics.lineTo(W_ / 2, H_);
        
        circle.graphics.lineStyle(1, 0x000000);
        circle.graphics.beginFill(0x000000, 0.5);
        circle.graphics.drawCircle(0, 0, 4);
        circle.graphics.endFill();
        circle.visible = false;
        
        // Register once on this sprite only. CLICK bubbles up from psp/circle, so a
        // second listener on a child would handle the same click twice.
        this.addEventListener(MouseEvent.CLICK, onClick);
    }
    
    /**
     * Stores the clicked position as a labelled point, draws it, and toggles
     * the label for the next click.
     */
    private function onClick(e:MouseEvent):void {
        // Use this sprite's local mouse position: e.target may be a child
        // (e.g. circle) whose local origin is not ours.
        var px:Number = this.mouseX;
        var py:Number = this.mouseY;
        var nx:Number = ( px * (maxX - minX) ) / W + minX;
        var ny:Number = ( (H - py) * (maxY - minY) ) / H + minY;
        
        pointType.push(type);
        pointX.push(nx);
        pointY.push(ny);
        
        // psp sits at (0, 0) inside this sprite, so local coordinates match.
        psp.graphics.beginFill(type == 1 ? 0xffa500 : 0x4169e1);
        psp.graphics.drawCircle(px, py, 3);
        psp.graphics.endFill();
        
        type = -type;
    }
    
    /**
     * Moves the highlight marker to the plane point (x, y) and makes it visible.
     */
    public function selectPoint(x:Number, y:Number):void {
        var nx:int = ( (x - minX) / (maxX - minX) ) * W;
        var ny:int = H - ( (y - minY) / (maxY - minY) ) * H;
        circle.visible = true;
        circle.x = nx;
        circle.y = ny;
    }
}

// Perceptron-based linear classifier
/**
 * Perceptron classifier over sparse, named features.
 *
 * Inputs are plain Objects mapping feature name -> value. A constant feature
 * "bias" is implied: when an input omits it (or it is NaN), the value of
 * `bias` is used instead. Input objects are never modified.
 */
class Perceptron {
    // Weight per feature name; w["bias"] is the weight of the bias feature.
    private var w:Object;
    // Value used for the "bias" feature when an input does not provide one.
    public var bias:Number;
    // A sample triggers a weight update when t * score <= margin.
    private var margin:Number;
    
    /**
     * @param margin_ update threshold on t * score (0 = classic perceptron)
     * @param bias_   value of the implicit bias feature; also the initial bias weight
     */
    public function Perceptron(margin_:Number = 0.0, bias_:Number = 1.0) {
        bias = bias_;
        margin = margin_;
        
        w = new Object();
        w["bias"] = bias;
    }
    
    /** Returns the bias feature value of x, falling back to `bias` when missing or NaN. */
    private function biasValue(x:Object):Number {
        var b:Number = x["bias"];
        return isNaN(b) ? bias : b;
    }
    
    /**
     * Returns the raw linear score: the sum of x[k] * w[k] over features with a
     * known weight, plus the bias term. A positive score means the positive class.
     */
    public function predict(x:Object):Number {
        var ret:Number = 0;
        
        for (var key:String in x) {
            // The bias term is handled below so it is included even when x omits it.
            if (key == "bias") continue;
            if (!isNaN(w[key])) {
                ret += x[key] * w[key];
            }
        }
        ret += biasValue(x) * w["bias"];
        
        return ret;
    }
    
    /**
     * Batch training: runs `loop` passes over the samples, in order.
     * @param t    labels (+1 / -1), parallel to x
     * @param x    feature objects
     * @param loop number of passes
     */
    public function train_all(t:Array, x:Array, loop:int):void {
        for (var l:int = 0; l < loop; l++) {
            for (var i:int = 0; i < t.length; i++) {
                train(t[i], x[i]);
            }
        }
    }
    
    /**
     * One online perceptron step. If the sample is misclassified, or lies
     * within the margin (t * score <= margin), t * x is added to the weights.
     * The bias weight is updated with the bias feature value.
     * @param t label, +1 or -1
     * @param x feature object (not modified)
     */
    public function train(t:int, x:Object):void {
        var f:Number = predict(x);
        if (t * f <= margin) {
            for (var key:String in x) {
                if (key == "bias") continue;
                if (isNaN(w[key])) {
                    // First time this feature is seen.
                    w[key] = t * x[key];
                } else {
                    w[key] += t * x[key];
                }
            }
            w["bias"] += t * biasValue(x);
        }
    }
}