/*
 * Decompiled with CFR 0.152.
 */
package com.github.rjeschke.neetutils.ai;

import com.github.rjeschke.neetutils.ai.Layer;
import com.github.rjeschke.neetutils.ai.Net;
import com.github.rjeschke.neetutils.ai.Trainer;

/**
 * Trains a {@link Net} with back-propagation plus a momentum term.
 *
 * <p>Each call to {@link #train(double[], double[])} processes one sample:
 * <ol>
 *   <li>it runs the net forward;</li>
 *   <li>it back-propagates the error using the derivative {@code o * (1 - o)};</li>
 *   <li>it adds {@code step * delta * input + alpha * previousChange} to every weight. The bias
 *       column at index {@code numInputs} of each row uses an implicit input of 1.</li>
 * </ol>
 * The change applied to each weight is remembered in {@link #oldDeltas}, a cleared clone of the
 * trained net, so the next step can re-apply a fraction {@code alpha} of it.
 *
 * <p>After each step the minimum, maximum and average absolute weight change can be read via the
 * {@code getDeltaChange*} accessors.
 *
 * <p>NOTE(review): the {@code o * (1 - o)} derivative assumes every layer uses a logistic sigmoid
 * activation; that is not visible from this file — confirm against {@code Layer}/{@code Net}.
 *
 * @deprecated retained for compatibility; no replacement is referenced from this file.
 */
@Deprecated
public class BackpropMomentumTrainer
implements Trainer {
    /** Learning rate applied to each raw gradient term. */
    double step;
    /** Momentum factor: fraction of the previous weight change that is re-applied each step. */
    double alpha;
    /** The network being trained; its weight matrices are modified in place. */
    Net net;
    /** Same structure as {@link #net}; stores the weight change applied by the previous step. */
    Net oldDeltas;
    /** Smallest absolute weight change of the last training step. */
    double min;
    /** Largest absolute weight change of the last training step. */
    double max;
    /** Running total of absolute weight changes; holds their mean once a training step ends. */
    double sum;

    /**
     * Creates a trainer for {@code net}.
     *
     * @param net   the network to train (updated in place)
     * @param step  learning rate
     * @param alpha momentum factor
     */
    public BackpropMomentumTrainer(Net net, double step, double alpha) {
        this.net = net;
        this.step = step;
        this.alpha = alpha;
        // Presumably clear() zeroes the cloned weights, so the first step has no momentum.
        this.oldDeltas = net.clone().clear();
    }

    /** @return the smallest absolute weight change of the last training step */
    public double getDeltaChangeMinimum() {
        return this.min;
    }

    /** @return the largest absolute weight change of the last training step */
    public double getDeltaChangeMaximum() {
        return this.max;
    }

    /** @return the mean absolute weight change of the last training step */
    public double getDeltaChangeAverage() {
        return this.sum;
    }

    /**
     * Performs one back-propagation-with-momentum step on a single sample.
     *
     * @param input          the sample fed to the net
     * @param expectedOutput the target values; must have at least as many entries as the net has
     *                       outputs
     */
    @Override
    public void train(double[] input, double[] expectedOutput) {
        final Layer.State[] netState = this.net.createExtraStates(input);
        // Scratch space for the per-layer error terms; the input array only determines its shape.
        final Layer.State[] deltas = this.net.createExtraStates(new double[this.net.numInputs]);
        this.net.run(netState);
        computeOutputDeltas(netState, deltas, expectedOutput);
        this.backpropagate(netState, deltas);
        this.applyWeightUpdates(netState, deltas);
    }

    /** Fills the output deltas with {@code o * (1 - o) * (expected - o)} per output value. */
    private static void computeOutputDeltas(Layer.State[] netState, Layer.State[] deltas,
            double[] expectedOutput) {
        final int last = netState.length - 1;
        final double[] os = netState[last].values;
        final double[] ds = deltas[last].values;
        for (int i = 0; i < os.length; ++i) {
            final double o = os[i];
            ds[i] = o * (1.0 - o) * (expectedOutput[i] - o);
        }
    }

    /**
     * Propagates deltas from the output back towards the input.
     *
     * <p>For layer {@code i}, the input-side deltas {@code deltas[i]} are derived from its
     * output-side deltas {@code deltas[i + 1]} through the current (not yet updated) weights.
     * {@code deltas[0]} would belong to the raw input and is never read by the weight update, so
     * the loop stops at layer 1.
     */
    private void backpropagate(Layer.State[] netState, Layer.State[] deltas) {
        for (int i = this.net.layers.length - 1; i >= 1; --i) {
            final Layer l = this.net.layers[i];
            final double[] os = netState[i].values;
            final double[] ds = deltas[i + 1].values;
            final double[] dso = deltas[i].values;
            for (int x = 0; x < l.numInputs; ++x) {
                double e = 0.0;
                for (int y = 0; y < l.numOutputs; ++y) {
                    e += ds[y] * l.matrix[y * l.width + x];
                }
                final double o = os[x];
                dso[x] = o * (1.0 - o) * e;
            }
        }
    }

    /**
     * Updates every weight and refreshes the delta-change statistics.
     *
     * <p>Each weight receives {@code step * delta * input + alpha * previousChange}. Each change is
     * stored in {@link #oldDeltas} as the momentum source for the next step. Afterwards
     * {@link #min}, {@link #max} and {@link #sum} describe the absolute changes of this step.
     */
    private void applyWeightUpdates(Layer.State[] netState, Layer.State[] deltas) {
        this.sum = 0.0;
        this.max = 0.0;
        this.min = Double.MAX_VALUE;
        int runs = 0;
        for (int i = 0; i < this.net.layers.length; ++i) {
            final Layer l = this.net.layers[i];
            final Layer prev = this.oldDeltas.layers[i];
            final double[] in = netState[i].values;
            final double[] ds = deltas[i + 1].values;
            // numInputs weights plus one bias per output neuron.
            runs += (l.numInputs + 1) * l.numOutputs;
            for (int y = 0; y < l.numOutputs; ++y) {
                final int row = y * l.width;
                final double d = this.step * ds[y];
                // The bias column sits right after the input weights and sees an implicit input of 1.
                final int bias = row + l.numInputs;
                final double db = this.updateDeltas(d + this.alpha * prev.matrix[bias]);
                prev.matrix[bias] = db;
                l.matrix[bias] += db;
                for (int x = 0; x < l.numInputs; ++x) {
                    final int w = row + x;
                    final double dw = this.updateDeltas(d * in[x] + this.alpha * prev.matrix[w]);
                    prev.matrix[w] = dw;
                    l.matrix[w] += dw;
                }
            }
        }
        if (runs != 0) {
            this.sum /= runs;
        } else {
            // Nothing was updated: report 0 (as before any training) instead of Double.MAX_VALUE.
            this.min = 0.0;
        }
    }

    /** Folds {@code |delta|} into the min/max/sum statistics and returns {@code delta} unchanged. */
    private double updateDeltas(double delta) {
        double da = Math.abs(delta);
        this.min = Math.min(this.min, da);
        this.max = Math.max(this.max, da);
        this.sum += da;
        return delta;
    }

    /** Sets the learning rate used by subsequent training steps. */
    public void setStep(double v) {
        this.step = v;
    }

    /** Sets the momentum factor used by subsequent training steps. */
    public void setAlpha(double v) {
        this.alpha = v;
    }
}

