package kdo.neuralNetwork.feedforward;

import java.io.Serializable;
import java.util.Arrays;
import java.util.stream.IntStream;
import kdo.neuralNetwork.optimizer.IOptimizer;
import kdo.neuralNetwork.transfer.ITransferFunction;
import kdo.util.IRandomSource;

/* loaded from: input_file:kdo/neuralNetwork/feedforward/Layer.class */
public abstract class Layer implements Serializable {
    private static final long serialVersionUID = 1;

    /** Weight matrix indexed as {@code weight[input][output]}. */
    protected float[][] weight;

    /** One bias value per output neuron. */
    protected float[] bias;

    /** Activations computed by the most recent {@link #recall(float[])}. */
    protected final float[] outputLayer;

    /** Pre-activation values (weighted sum plus bias) from the most recent {@link #recall(float[])}. */
    protected float[] summedInputs;

    /** Copy of the input passed to the most recent {@link #recall(float[])}; read by {@link #adjustWeights}. */
    protected final float[] inputLayer;

    /** Function applied to each summed input to produce the output activation. */
    protected final ITransferFunction transferFunction;

    /** Identifier forwarded to the optimizer (presumably this layer's position in the network — confirm). */
    private int index;

    /**
     * Creates a layer whose weights and biases are drawn from
     * {@code random.nextFloat() * 2 - 1} (uniform in [-1, 1) assuming {@code nextFloat()} is in [0, 1)).
     *
     * @param index identifier of this layer, forwarded to the optimizer
     * @param inputCount number of inputs; must be at least 1
     * @param outputCount number of outputs (neurons); must be at least 1
     * @param transferFunction activation applied to every neuron
     * @param random source of the initial random weights
     * @throws IllegalArgumentException if {@code inputCount} or {@code outputCount} is less than 1
     */
    public Layer(int index, int inputCount, int outputCount, ITransferFunction transferFunction,
            IRandomSource random) {
        // Validate inside the delegating call: the private constructor allocates arrays sized by these
        // counts, so validating afterwards would let a negative count escape as NegativeArraySizeException.
        this(index, requireAtLeastOne(inputCount, "input"), requireAtLeastOne(outputCount, "output"),
                transferFunction);
        this.weight = new float[inputCount][outputCount];
        this.bias = new float[outputCount];
        initializeWeights(0.0f, 1.0f, random);
    }

    /**
     * Creates a layer using the given weights and biases. The arrays are used directly (not copied).
     *
     * @param index identifier of this layer, forwarded to the optimizer
     * @param weights weight matrix indexed as {@code weights[input][output]}; must be rectangular
     * @param biases one bias per output
     * @param transferFunction activation applied to every neuron
     * @throws IllegalArgumentException if there are no inputs or outputs, or a weight row's length
     *     differs from {@code biases.length}
     */
    public Layer(int index, float[][] weights, float[] biases, ITransferFunction transferFunction) {
        this(index, requireAtLeastOne(weights.length, "input"), requireAtLeastOne(biases.length, "output"),
                transferFunction);
        checkWeightRows(weights, biases.length);
        this.weight = weights;
        this.bias = biases;
    }

    private Layer(int index, int inputCount, int outputCount, ITransferFunction transferFunction) {
        this.index = index;
        this.transferFunction = transferFunction;
        this.summedInputs = new float[outputCount];
        this.outputLayer = new float[outputCount];
        this.inputLayer = new float[inputCount];
    }

    /** Returns {@code count} unchanged, or throws if it is below 1. {@code what} names the dimension. */
    private static int requireAtLeastOne(int count, String what) {
        if (count < 1) {
            throw new IllegalArgumentException("Need at least one " + what + " but had: " + count);
        }
        return count;
    }

    /** Ensures every weight row is non-null and has exactly {@code outputCount} entries. */
    private static void checkWeightRows(float[][] weights, int outputCount) {
        for (int row = 0; row < weights.length; row++) {
            if (weights[row] == null) {
                throw new IllegalArgumentException("Weight row " + row + " is null");
            }
            if (weights[row].length != outputCount) {
                throw new IllegalArgumentException("Weight row " + row + " has length " + weights[row].length
                        + " but there are " + outputCount + " biases");
            }
        }
    }

    /**
     * Fills weights (row by row) and then biases with {@code center + halfRange * (2r - 1)}, where
     * {@code r} is the next value from {@code random}. The draw order is fixed so seeded runs are reproducible.
     */
    private void initializeWeights(float center, float halfRange, IRandomSource random) {
        for (float[] row : this.weight) {
            for (int out = 0; out < row.length; out++) {
                row[out] = (((random.nextFloat() * 2.0f) * halfRange) + center) - halfRange;
            }
        }
        for (int out = 0; out < this.bias.length; out++) {
            this.bias[out] = (((random.nextFloat() * 2.0f) * halfRange) + center) - halfRange;
        }
    }

    /**
     * Runs a forward pass: stores a copy of {@code input}, computes
     * {@code summedInputs[o] = sum_i weight[i][o] * input[i] + bias[o]}, and applies the transfer function.
     *
     * @param input values for each input; length must equal {@link #getInputLength()}
     * @return this layer's internal output array (overwritten by the next call)
     * @throws IllegalArgumentException if {@code input} has the wrong length
     */
    public float[] recall(float[] input) {
        checkInputLength(input.length);
        System.arraycopy(input, 0, this.inputLayer, 0, input.length);
        int outputCount = this.weight[0].length;
        Arrays.fill(this.summedInputs, 0, outputCount, 0.0f);
        // Walk the matrix row by row so the inner loop reads contiguous memory. Each output still
        // accumulates its terms in increasing input order, so float results match a column-wise walk.
        // Reading the private copy (inputLayer) also keeps this correct if the caller passes an array
        // that aliases outputLayer.
        for (int in = 0; in < this.weight.length; in++) {
            float[] row = this.weight[in];
            float x = this.inputLayer[in];
            for (int out = 0; out < outputCount; out++) {
                this.summedInputs[out] += row[out] * x;
            }
        }
        for (int out = 0; out < outputCount; out++) {
            this.summedInputs[out] += this.bias[out];
            this.outputLayer[out] = this.transferFunction.transfer(this.summedInputs[out]);
        }
        return this.outputLayer;
    }

    /**
     * Hands the gradient of every weight and bias to {@code optimizer}, using the input stored by the last
     * {@link #recall(float[])}. For weights the value passed is {@code input[i] * delta[o]}; for biases it is
     * {@code delta[o]}, with the bias addressed as input index {@code getInputLength()}.
     * (Presumably {@code apply(array, gradient, layer, input, output)} updates {@code array[output]} — confirm
     * against {@code IOptimizer}.)
     *
     * @param optimizer receives one call per weight and per bias
     * @param deltas one delta per output
     * @return mean of the squared deltas, or 0 if {@code deltas} is empty
     */
    public float adjustWeights(IOptimizer optimizer, float[] deltas) {
        if (deltas.length == 0) {
            return 0.0f; // avoid returning 0/0 = NaN
        }
        int inputCount = this.weight.length;
        float squaredSum = 0.0f;
        for (int out = 0; out < deltas.length; out++) {
            float delta = deltas[out];
            squaredSum += delta * delta;
            for (int in = 0; in < inputCount; in++) {
                optimizer.apply(this.weight[in], this.inputLayer[in] * delta, this.index, in, out);
            }
            optimizer.apply(this.bias, delta, this.index, inputCount, out);
        }
        return squaredSum / deltas.length;
    }

    /** Returns the number of inputs (rows of the weight matrix). */
    public int getInputLength() {
        return this.weight.length;
    }

    /** Returns the number of outputs (columns of the weight matrix). */
    public int getOutputLength() {
        return this.weight[0].length;
    }

    private void checkInputLength(int length) {
        if (length != this.weight.length) {
            throw new IllegalArgumentException("Input size " + length + " does not match network input size " + this.weight.length);
        }
    }

    /**
     * Computes this layer's deltas; the meaning of the argument is defined by subclasses
     * (NOTE(review): presumably a target/error vector for output layers or a propagated delta for hidden layers).
     */
    public abstract float[] calculateDelta(float[] fArr);

    /**
     * Propagates deltas back through the weights: {@code result[i] = sum_o deltas[o] * weight[i][o]}.
     *
     * @param deltas one delta per output
     * @return a new array with one value per input
     */
    public float[] propagateDelta(float[] deltas) {
        float[] result = new float[this.weight.length];
        for (int in = 0; in < this.weight.length; in++) {
            float[] row = this.weight[in];
            float sum = 0.0f;
            for (int out = 0; out < deltas.length; out++) {
                sum += deltas[out] * row[out];
            }
            result[in] = sum;
        }
        return result;
    }

    @Override
    public String toString() {
        return Arrays.deepToString(this.weight) + " bias: " + Arrays.toString(this.bias) + "\n";
    }

    /** Returns a {@code double} copy of the activations from the most recent {@link #recall(float[])}. */
    public double[] getOutputLayer() {
        return IntStream.range(0, this.outputLayer.length).mapToDouble(i -> this.outputLayer[i]).toArray();
    }
}
