const utils = require('./utils');

/**
 * A fully connected feed-forward network built from a chain of `Level`s.
 *
 * Level `k` maps `neuronCounts[k]` inputs to `neuronCounts[k + 1]` outputs, so a
 * network described by N neuron counts has N - 1 levels. All operations are
 * static methods that take the network as an explicit argument.
 */
export default class NeuralNetwork {
    /**
     * @param {number[]} neuronCounts - Neuron count of every layer, input layer first.
     *   Fewer than two entries yields a network with no levels.
     */
    constructor(neuronCounts) {
        this.levels = [];

        for (let i = 0; i < neuronCounts.length - 1; i++) {
            this.levels.push(new Level(neuronCounts[i], neuronCounts[i + 1]));
        }
    }

    /**
     * Runs `givenInputs` through every level in order.
     *
     * @param {number[]} givenInputs - Values for the input layer.
     * @param {NeuralNetwork} network - Network to evaluate.
     * @returns {number[]} Whatever the final level's `Level.feedForward` call returns.
     * @throws {Error} If the network has no levels.
     */
    static feedForward(givenInputs, network) {
        // Without this guard an empty network fails deep inside Level.feedForward
        // with an unhelpful "cannot read properties of undefined".
        if (network.levels.length === 0) {
            throw new Error('NeuralNetwork.feedForward: network has no levels (needs at least two neuron counts)');
        }

        let outputs = givenInputs;
        for (const level of network.levels) {
            outputs = Level.feedForward(outputs, level);
        }
        return outputs;
    }

    /**
     * Performs one back-propagation + weight-update step toward `expected`.
     *
     * Relies on the `inputs`/`outputs` each level cached during the most recent
     * `feedForward`, so call it right after a forward pass on the matching input.
     * Gradients are computed from the output level backwards (a hidden level
     * needs the `gamma` and `weights` of the level after it), and weights are
     * only updated once every level has its deltas, so no gradient is computed
     * against already-updated weights. A network with no levels is a no-op.
     *
     * @param {number[]} expected - Target values for the output layer.
     * @param {NeuralNetwork} network - Network to train in place.
     */
    static backPropagate(expected, network) {
        const { levels } = network;
        const lastIndex = levels.length - 1;
        if (lastIndex < 0) {
            return;
        }

        Level.BackPropOutput(expected, levels[lastIndex]);

        for (let i = lastIndex - 1; i >= 0; i--) {
            const next = levels[i + 1];
            Level.BackPropHidden(next.gamma, next.weights, levels[i]);
        }

        for (const level of levels) {
            Level.UpdateWeights(level);
        }
    }

    /**
     * Resets the weights of every level via `Level.zero`
     * (which currently sets each weight to 0.5, not 0).
     *
     * @param {NeuralNetwork} network - Network to reset in place.
     */
    static zeroOut(network) {
        for (const level of network.levels) {
            Level.zero(level);
        }
    }

    /**
     * Moves every bias and weight toward a fresh random target in [-1, 1).
     *
     * @param {NeuralNetwork} network - Network to mutate in place.
     * @param {number} [amount=1] - Interpolation factor passed to `utils.lerp`;
     *   presumably 0 keeps the current value and 1 replaces it (depends on `utils.lerp`).
     */
    static mutate(network, amount = 1) {
        const nudge = (value) => utils.lerp(value, Math.random() * 2 - 1, amount);

        for (const level of network.levels) {
            for (let i = 0; i < level.biases.length; i++) {
                level.biases[i] = nudge(level.biases[i]);
            }
            for (const row of level.weights) {
                for (let j = 0; j < row.length; j++) {
                    row[j] = nudge(row[j]);
                }
            }
        }
    }
}


/**
 * A single fully connected layer of the network.
 *
 * `weights[i][j]` is the weight from input `i` to output neuron `j`, and
 * `biases[j]` is neuron `j`'s bias. `feedForward` caches the values it saw in
 * `inputs` and the values it produced in `outputs`; the back-propagation
 * methods read those caches, so they must run after a forward pass.
 *
 * Training is a two-step process driven by the caller: `BackPropOutput` /
 * `BackPropHidden` fill `gamma` (per-neuron error gradient) and `weightsDelta`,
 * then `UpdateWeights` applies them to the weights and biases.
 *
 * NOTE(review): the forward activation is clamp((sum(x * w) + b) / inputCount, -1, 1),
 * but the backward pass scales errors by `TanHDer`, the derivative of tanh.
 * The clamp's actual derivative is 1 / inputCount inside (-1, 1) and 0 when
 * saturated, so gradients are only approximately right — confirm which
 * activation is intended before relying on training quality.
 */
export class Level {
    /**
     * @param {number} inputCount - Number of values fed into this level.
     * @param {number} outputCount - Number of neurons (output values) in this level.
     */
    constructor(inputCount, outputCount) {
        this.inputs = new Array(inputCount);   // inputs seen by the last feedForward
        this.outputs = new Array(outputCount); // activations produced by the last feedForward

        this.weights = [];
        this.weightsDelta = [];

        this.biases = new Array(outputCount);

        this.errors = new Array(outputCount).fill(0); // output error; only written on the output level
        this.gamma = new Array(outputCount).fill(0);  // error gradient per neuron (also the bias gradient)

        for (let i = 0; i < inputCount; i++) {
            this.weights[i] = new Array(outputCount);
            // Zero-filled rather than left as empty slots, so UpdateWeights is a
            // clean no-op until a back-propagation pass writes real deltas.
            this.weightsDelta[i] = new Array(outputCount).fill(0);
        }

        Level.#randomize(this);
    }

    /**
     * Sets every weight to 0.5. Despite the name, weights are not set to 0,
     * and biases are left unchanged.
     *
     * @param {Level} level - Level to reset in place.
     */
    static zero(level) {
        for (let i = 0; i < level.inputs.length; i++) {
            for (let j = 0; j < level.outputs.length; j++) {
                level.weights[i][j] = 0.5;
            }
        }
    }

    /** Initializes every weight and bias uniformly in [-1, 1). */
    static #randomize(level) {
        for (let i = 0; i < level.inputs.length; i++) {
            for (let j = 0; j < level.outputs.length; j++) {
                level.weights[i][j] = Math.random() * 2 - 1;
            }
        }

        for (let i = 0; i < level.biases.length; i++) {
            level.biases[i] = Math.random() * 2 - 1;
        }
    }

    /**
     * Computes this level's outputs and caches inputs/outputs for back-propagation.
     *
     * Each output is (weighted sum of inputs + bias) divided by the input count,
     * then clamped to [-1, 1].
     *
     * @param {number[]} givenInputs - At least `level.inputs.length` values.
     * @param {Level} level - Level to evaluate; its `inputs`/`outputs` are overwritten.
     * @returns {number[]} `level.outputs` itself (not a copy).
     */
    static feedForward(givenInputs, level) {
        const inputCount = level.inputs.length;

        for (let i = 0; i < inputCount; i++) {
            level.inputs[i] = givenInputs[i];
        }

        for (let j = 0; j < level.outputs.length; j++) {
            let sum = 0;
            for (let i = 0; i < inputCount; i++) {
                sum += level.inputs[i] * level.weights[i][j];
            }
            sum += level.biases[j];

            // Averaging by the input count keeps the pre-clamp value in a similar
            // range regardless of layer width.
            level.outputs[j] = Math.min(1, Math.max(-1, sum / inputCount));
        }

        return level.outputs;
    }

    /**
     * Back-propagation step for the final (output) level: computes the error
     * against `expected`, the per-neuron `gamma`, and the weight deltas.
     *
     * @param {number[]} expected - Target value for each output neuron.
     * @param {Level} level - The network's last level, after a forward pass.
     */
    static BackPropOutput(expected, level) {
        for (let j = 0; j < level.outputs.length; j++) {
            level.errors[j] = level.outputs[j] - expected[j];
            level.gamma[j] = level.errors[j] * Level.TanHDer(level.outputs[j]);
        }

        Level.#computeWeightDeltas(level);
    }

    /**
     * Back-propagation step for a hidden level: pulls the error gradient back
     * from the following level, then computes this level's weight deltas.
     *
     * @param {number[]} gammaForward - `gamma` of the next level.
     * @param {number[][]} weightsForward - `weights` of the next level, indexed
     *   `[thisLevelOutput][nextLevelOutput]`.
     * @param {Level} level - The hidden level, after a forward pass.
     */
    static BackPropHidden(gammaForward, weightsForward, level) {
        for (let j = 0; j < level.outputs.length; j++) {
            let gradient = 0;
            for (let k = 0; k < gammaForward.length; k++) {
                gradient += gammaForward[k] * weightsForward[j][k];
            }
            level.gamma[j] = gradient * Level.TanHDer(level.outputs[j]);
        }

        Level.#computeWeightDeltas(level);
    }

    /** Fills `weightsDelta[i][j] = gamma[j] * inputs[i]` from the current gamma. */
    static #computeWeightDeltas(level) {
        for (let i = 0; i < level.inputs.length; i++) {
            for (let j = 0; j < level.outputs.length; j++) {
                level.weightsDelta[i][j] = level.gamma[j] * level.inputs[i];
            }
        }
    }

    /**
     * Derivative of tanh expressed through its output: for y = tanh(x),
     * dy/dx = 1 - y^2.
     *
     * @param {number} value - A tanh output y.
     * @returns {number} 1 - y^2.
     */
    static TanHDer(value) {
        return 1.0 - (value * value);
    }

    /**
     * Applies one gradient-descent step from the last back-propagation pass:
     * each weight moves by `-weightsDelta * learningRate` and each bias by
     * `-gamma * learningRate`. Deltas are not cleared, so calling this twice
     * without a new back-propagation pass applies the same step twice.
     *
     * @param {Level} level - Level to update in place.
     * @param {number} [learningRate=0.005] - Gradient-descent step size.
     */
    static UpdateWeights(level, learningRate = 0.005) {
        for (let i = 0; i < level.inputs.length; i++) {
            for (let j = 0; j < level.outputs.length; j++) {
                // Falsy deltas (0, or null/undefined/NaN from an unset or
                // deserialized level) are skipped rather than corrupting the weight.
                if (level.weightsDelta[i][j]) {
                    level.weights[i][j] -= level.weightsDelta[i][j] * learningRate;
                }
            }
        }

        // The bias enters the sum like a weight on a constant input of 1, so its
        // gradient is the neuron's gamma; previously biases were never trained.
        for (let j = 0; j < level.outputs.length; j++) {
            if (level.gamma[j]) {
                level.biases[j] -= level.gamma[j] * learningRate;
            }
        }
    }
}