/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicmodelzoo.tabular;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.LambdaBlock;
import ai.djl.nn.ParallelBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.nn.core.SparseMax;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.norm.GhostBatchNorm;
import ai.djl.training.ParameterStore;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * A TabNet block for tabular data.
 *
 * <p>The forward pass works as follows:
 *
 * <ol>
 *   <li>It flattens every non-batch axis of the single input into one feature axis.
 *   <li>It normalizes the result with {@link BatchNorm} and runs an initial feature transformer.
 *   <li>It runs {@code numSteps - 1} {@link DecisionStep}s. Each step receives the normalized
 *       input, the "attention" columns ({@code numA} of them) produced by the previous transformer,
 *       and a priors array.
 *   <li>The first {@code numD} columns of every step's output go through ReLU and are summed.
 *   <li>The sum is projected to {@code finalOutDim} units by a final {@link Linear} layer.
 * </ol>
 *
 * <p>The forward pass returns an {@link NDList} of two arrays: the projected output and the
 * element-wise sparsity loss accumulated over all decision steps.
 */
public final class TabNet extends AbstractBlock {

    private static final byte VERSION = 1;

    private final Block firstStep;
    private final List<Block> steps;
    private final Block fullyConnected;
    private final Block batchNorm;
    private final int numD;
    private final int numA;

    /**
     * Builds the block tree from the builder's hyper-parameters.
     *
     * <p>{@code numShared} fully connected layers with {@code 2 * (numD + numA)} units are created
     * once. The same instances are reused by the initial feature transformer and by every decision
     * step. Child blocks are registered in a fixed order (batchNorm, shared layers, first feature
     * transformer, steps, final layer); parameter naming depends on that order.
     *
     * @param builder the builder holding the hyper-parameters
     */
    private TabNet(Builder builder) {
        super(VERSION);
        this.numD = builder.numD;
        this.numA = builder.numA;
        this.batchNorm =
                addChildBlock(
                        "batchNorm",
                        BatchNorm.builder().optMomentum(builder.batchNormMomentum).build());

        List<Block> sharedBlocks = new ArrayList<>();
        for (int i = 0; i < builder.numShared; i++) {
            sharedBlocks.add(
                    addChildBlock(
                            "sharedfc" + i,
                            Linear.builder()
                                    .setUnits(2L * (builder.numA + builder.numD))
                                    .build()));
        }

        this.firstStep =
                addChildBlock(
                        "featureTransformer",
                        featureTransformer(
                                sharedBlocks,
                                builder.numD + builder.numA,
                                builder.numIndependent,
                                builder.virtualBatchSize,
                                builder.batchNormMomentum));

        this.steps = new ArrayList<>();
        // Step 0 is the initial feature transformer above; decision steps are numbered from 1.
        for (int i = 1; i < builder.numSteps; i++) {
            this.steps.add(
                    addChildBlock(
                            "steps" + i,
                            new DecisionStep(
                                    builder.inputDim,
                                    builder.numD,
                                    builder.numA,
                                    sharedBlocks,
                                    builder.numIndependent,
                                    builder.virtualBatchSize,
                                    builder.batchNormMomentum)));
        }

        this.fullyConnected =
                addChildBlock(
                        "fullyConnected", Linear.builder().setUnits(builder.finalOutDim).build());
    }

    /**
     * Applies the TabNet gated linear unit along the second axis.
     *
     * <p>The first {@code units} columns are multiplied element-wise by the sigmoid of the
     * remaining columns. Callers in this class pass arrays with {@code 2 * units} columns.
     *
     * @param array the input array, indexed as {@code (rows, columns)}
     * @param units the split point, which is also the number of output columns
     * @return {@code array[:, :units] * sigmoid(array[:, units:])}
     */
    public static NDArray tabNetGLU(NDArray array, int units) {
        return array.get(":,:{}", units).mul(Activation.sigmoid(array.get(":, {}:", units)));
    }

    /**
     * Applies {@link #tabNetGLU(NDArray, int)} to the single array in {@code arrays}.
     *
     * @param arrays a list holding exactly one array
     * @param units the split point, which is also the number of output columns
     * @return a single-element list with the gated result
     */
    public static NDList tabNetGLU(NDList arrays, int units) {
        return new NDList(tabNetGLU(arrays.singletonOrThrow(), units));
    }

    /**
     * Wraps {@link #tabNetGLU(NDList, int)} in a {@link LambdaBlock}.
     *
     * @param units the split point, which is also the number of output columns
     * @return a parameterless block applying the gated linear unit
     */
    public static Block tabNetGLUBlock(int units) {
        return new LambdaBlock(arrays -> tabNetGLU(arrays, units), "tabNetGLU");
    }

    /**
     * Creates one GLU layer: a fully connected layer, then ghost batch norm, then the TabNet GLU.
     *
     * @param sharedBlock a fully connected layer to reuse, or {@code null} to create a new one
     *     with {@code 2 * outDim} units
     * @param outDim the number of output columns
     * @param virtualBatchSize the virtual batch size for {@link GhostBatchNorm}
     * @param batchNormMomentum the momentum for {@link GhostBatchNorm}
     * @return the sequential GLU layer
     */
    public static Block gluBlock(
            Block sharedBlock, int outDim, int virtualBatchSize, float batchNormMomentum) {
        SequentialBlock featureBlock = new SequentialBlock();
        if (sharedBlock != null) {
            featureBlock.add(sharedBlock);
        } else {
            featureBlock.add(Linear.builder().setUnits(2L * outDim).build());
        }
        featureBlock.add(
                GhostBatchNorm.builder()
                        .optVirtualBatchSize(virtualBatchSize)
                        .optMomentum(batchNormMomentum)
                        .build());
        featureBlock.add(tabNetGLUBlock(outDim));
        return featureBlock;
    }

    /**
     * Creates a feature transformer from a chain of GLU layers.
     *
     * <p>Layers built on {@code sharedBlocks} come first, followed by {@code numIndependent}
     * freshly created layers. Every layer after the first is wrapped in a scaled residual
     * connection {@code (f(x) + x) * sqrt(0.5)}. The first layer is never residual, because its
     * input width may differ from {@code outDim}. This holds whether it is shared or independent.
     *
     * @param sharedBlocks fully connected layers to reuse, possibly empty
     * @param outDim the number of output columns of every GLU layer
     * @param numIndependent the number of non-shared GLU layers to append
     * @param virtualBatchSize the virtual batch size for {@link GhostBatchNorm}
     * @param batchNormMomentum the momentum for {@link GhostBatchNorm}
     * @return the sequential feature transformer
     */
    public static Block featureTransformer(
            List<Block> sharedBlocks,
            int outDim,
            int numIndependent,
            int virtualBatchSize,
            float batchNormMomentum) {
        List<Block> allBlocks = new ArrayList<>();
        for (Block sharedBlock : sharedBlocks) {
            allBlocks.add(gluBlock(sharedBlock, outDim, virtualBatchSize, batchNormMomentum));
        }
        for (int i = 0; i < numIndependent; i++) {
            allBlocks.add(gluBlock(null, outDim, virtualBatchSize, batchNormMomentum));
        }

        SequentialBlock featureBlocks = new SequentialBlock();
        for (int i = 0; i < allBlocks.size(); i++) {
            Block glu = allBlocks.get(i);
            featureBlocks.add(i == 0 ? glu : residual(glu));
        }
        return featureBlocks;
    }

    /** Wraps {@code block} as {@code (block(x) + x) * sqrt(0.5)}. */
    private static Block residual(Block block) {
        return new ParallelBlock(
                ndLists -> {
                    NDArray unit = ndLists.get(0).singletonOrThrow();
                    NDArray parallel = ndLists.get(1).singletonOrThrow();
                    return new NDList(NDArrays.add(unit, parallel).mul(Math.sqrt(0.5)));
                },
                Arrays.asList(block, Blocks.identityBlock()));
    }

    /**
     * Collapses every non-batch axis into one. This mirrors the reshape in {@link
     * #forwardInternal}, so that shape inference and initialization agree with the forward pass.
     */
    private static Shape flatten(Shape shape) {
        if (shape.dimension() <= 2) {
            return shape;
        }
        return new Shape(shape.get(0), shape.size() / shape.get(0));
    }

    /** {@inheritDoc} */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDManager manager = inputs.getManager();
        NDArray input = inputs.singletonOrThrow();
        // Axis 0 is treated as the batch axis; all remaining axes become one feature axis.
        input = input.reshape(input.size(0), input.size() / input.size(0));

        NDArray x =
                batchNorm.forward(parameterStore, new NDList(input), training).singletonOrThrow();
        NDArray xa =
                firstStep
                        .forward(parameterStore, new NDList(x), training)
                        .singletonOrThrow()
                        .get(":," + numD + ":");

        // NOTE(review): priors are created once as ones and are never updated between steps.
        // Confirm that this is intended.
        NDArray priors = manager.ones(x.getShape());
        NDArray out = null;
        NDArray sparseLoss = null;
        for (Block step : steps) {
            NDList stepResult =
                    step.forward(parameterStore, new NDList(x, xa, priors), training);
            NDArray xte = stepResult.get(0);
            NDArray loss = stepResult.get(1);
            // Columns [0, numD) feed the decision output; [numD, end) feed the next attention.
            NDArray decision = Activation.relu(xte.get(":,:" + numD));
            out = out == null ? decision : out.add(decision);
            xa = xte.get(":," + numD + ":");
            sparseLoss = sparseLoss == null ? loss : sparseLoss.add(loss);
        }

        NDArray finalOutput =
                fullyConnected.forward(parameterStore, new NDList(out), training).singletonOrThrow();
        return new NDList(finalOutput, sparseLoss);
    }

    /** {@inheritDoc} */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        Shape[] xShapes = batchNorm.getOutputShapes(new Shape[] {flatten(inputShapes[0])});
        Shape[] firstShapes = firstStep.getOutputShapes(xShapes);
        // The attention input is the tail of the first transformer's output: numA columns.
        Shape xaShape = Shape.update(firstShapes[0], firstShapes[0].dimension() - 1, numA);
        Shape[] stepInputs = {xShapes[0], xaShape, xShapes[0]};

        Shape outputShape = new Shape();
        Shape lossShape = new Shape();
        for (Block step : steps) {
            Shape[] stepOutputs = step.getOutputShapes(stepInputs);
            outputShape = Shape.update(stepOutputs[0], stepOutputs[0].dimension() - 1, numD);
            lossShape = stepOutputs[1];
        }
        outputShape = fullyConnected.getOutputShapes(new Shape[] {outputShape})[0];
        return new Shape[] {outputShape, lossShape};
    }

    /** {@inheritDoc} */
    @Override
    protected void initializeChildBlocks(
            NDManager manager, DataType dataType, Shape... inputShapes) {
        Shape[] shapes = {flatten(inputShapes[0])};
        batchNorm.initialize(manager, dataType, shapes);
        Shape[] xShapes = batchNorm.getOutputShapes(shapes);
        firstStep.initialize(manager, dataType, xShapes);
        Shape[] firstShapes = firstStep.getOutputShapes(xShapes);
        // Must match forwardInternal: the attention input has numA (not numD) columns.
        Shape xaShape = Shape.update(firstShapes[0], firstShapes[0].dimension() - 1, numA);
        Shape[] stepInputs = {xShapes[0], xaShape, xShapes[0]};

        Shape outputShape = new Shape();
        for (Block step : steps) {
            step.initialize(manager, dataType, stepInputs);
            Shape[] stepOutputs = step.getOutputShapes(stepInputs);
            outputShape = Shape.update(stepOutputs[0], stepOutputs[0].dimension() - 1, numD);
        }
        fullyConnected.initialize(manager, dataType, outputShape);
    }

    /**
     * Creates a builder with default hyper-parameters.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /** The builder for {@link TabNet}. */
    public static class Builder {
        int inputDim = 128;
        int finalOutDim = 10;
        int numD = 64;
        int numA = 64;
        int numShared = 2;
        int numIndependent = 2;
        int numSteps = 5;
        int virtualBatchSize = 128;
        float batchNormMomentum = 0.9f;

        /**
         * Sets the number of input features. This is the width of each attention mask.
         *
         * @param inputDim the number of input features after flattening
         * @return this builder
         */
        public Builder setInputDim(int inputDim) {
            this.inputDim = inputDim;
            return this;
        }

        /**
         * Sets the number of output units of the final fully connected layer.
         *
         * @param outDim the output width
         * @return this builder
         */
        public Builder setOutDim(int outDim) {
            this.finalOutDim = outDim;
            return this;
        }

        /**
         * Sets the width of the decision part of each step's output.
         *
         * @param numD the decision width
         * @return this builder
         */
        public Builder optNumD(int numD) {
            this.numD = numD;
            return this;
        }

        /**
         * Sets the width of the attention part of each step's output.
         *
         * @param numA the attention width
         * @return this builder
         */
        public Builder optNumA(int numA) {
            this.numA = numA;
            return this;
        }

        /**
         * Sets the number of fully connected layers shared across all feature transformers.
         *
         * @param numShared the number of shared layers
         * @return this builder
         */
        public Builder optNumShared(int numShared) {
            this.numShared = numShared;
            return this;
        }

        /**
         * Sets the number of non-shared GLU layers in each feature transformer.
         *
         * @param numIndependent the number of independent layers
         * @return this builder
         */
        public Builder optNumIndependent(int numIndependent) {
            this.numIndependent = numIndependent;
            return this;
        }

        /**
         * Sets the total number of steps: the initial feature transformer plus the decision steps.
         *
         * @param numSteps the number of steps; must be at least 2
         * @return this builder
         */
        public Builder optNumSteps(int numSteps) {
            this.numSteps = numSteps;
            return this;
        }

        /**
         * Sets the virtual batch size used by every {@link GhostBatchNorm}.
         *
         * @param virtualBatchSize the virtual batch size
         * @return this builder
         */
        public Builder optVirtualBatchSize(int virtualBatchSize) {
            this.virtualBatchSize = virtualBatchSize;
            return this;
        }

        /**
         * Sets the momentum used by every batch-norm layer.
         *
         * @param batchNormMomentum the momentum
         * @return this builder
         */
        public Builder optBatchNormMomentum(float batchNormMomentum) {
            this.batchNormMomentum = batchNormMomentum;
            return this;
        }

        /**
         * Builds a standalone attention transformer using this builder's batch-norm settings.
         *
         * @param units the number of units of the transformer's fully connected layer, i.e. the
         *     width of the produced mask
         * @return the attention transformer block
         */
        public Block buildAttentionTransformer(int units) {
            return new AttentionTransformer(units, virtualBatchSize, batchNormMomentum);
        }

        /**
         * Builds the {@link TabNet} block.
         *
         * @return the block
         * @throws IllegalArgumentException if {@code numSteps < 2}, which would leave no decision
         *     step to produce an output
         */
        public Block build() {
            if (numSteps < 2) {
                throw new IllegalArgumentException(
                        "numSteps must be at least 2 (initial transformer plus one decision"
                                + " step), got "
                                + numSteps);
            }
            return new TabNet(this);
        }
    }

    /**
     * One TabNet decision step.
     *
     * <p>The input list is {@code {x, a, priors}}. The attention transformer maps {@code (a,
     * priors)} to a mask. The step returns {@code {featureTransformer(x), -mask * log(mask +
     * 1e-10)}}.
     */
    public static final class DecisionStep extends AbstractBlock {

        private static final byte VERSION = 1;

        private final Block featureTransformer;
        private final Block attentionTransformer;

        /**
         * Creates a decision step.
         *
         * @param inputDim the mask width produced by the attention transformer
         * @param numD the decision width of the feature transformer output
         * @param numA the attention width of the feature transformer output
         * @param shared fully connected layers shared with other feature transformers
         * @param nInd the number of independent GLU layers
         * @param virtualBatchSize the virtual batch size for {@link GhostBatchNorm}
         * @param batchNormMomentum the batch-norm momentum
         */
        public DecisionStep(
                int inputDim,
                int numD,
                int numA,
                List<Block> shared,
                int nInd,
                int virtualBatchSize,
                float batchNormMomentum) {
            super(VERSION);
            this.featureTransformer =
                    addChildBlock(
                            "featureTransformer",
                            TabNet.featureTransformer(
                                    shared, numD + numA, nInd, virtualBatchSize, batchNormMomentum));
            this.attentionTransformer =
                    addChildBlock(
                            "attentionTransformer",
                            new AttentionTransformer(
                                    inputDim, virtualBatchSize, batchNormMomentum));
        }

        /** {@inheritDoc} */
        @Override
        protected NDList forwardInternal(
                ParameterStore parameterStore,
                NDList inputs,
                boolean training,
                PairList<String, Object> params) {
            NDArray x = inputs.get(0);
            NDArray a = inputs.get(1);
            NDArray priors = inputs.get(2);
            NDArray mask =
                    attentionTransformer
                            .forward(parameterStore, new NDList(a, priors), training)
                            .singletonOrThrow();
            // Element-wise entropy term; 1e-10 keeps log() finite where the mask is exactly 0.
            NDArray sparseLoss = mask.mul(-1).mul(NDArrays.add(mask, 1.0E-10).log());
            // NOTE(review): the mask only feeds the sparsity loss; it is not applied to x before
            // the feature transformer. Confirm that this is intended.
            NDList x1 = featureTransformer.forward(parameterStore, new NDList(x), training);
            return new NDList(x1.singletonOrThrow(), sparseLoss);
        }

        /** {@inheritDoc} */
        @Override
        public Shape[] getOutputShapes(Shape[] inputShapes) {
            Shape[] x1Shape = featureTransformer.getOutputShapes(new Shape[] {inputShapes[0]});
            Shape[] lossShape =
                    attentionTransformer.getOutputShapes(
                            new Shape[] {inputShapes[1], inputShapes[2]});
            return new Shape[] {x1Shape[0], lossShape[0]};
        }

        /** {@inheritDoc} */
        @Override
        protected void initializeChildBlocks(
                NDManager manager, DataType dataType, Shape... inputShapes) {
            attentionTransformer.initialize(manager, dataType, inputShapes[1], inputShapes[2]);
            featureTransformer.initialize(manager, dataType, inputShapes[0]);
        }
    }

    /**
     * The attention transformer of a decision step.
     *
     * <p>It computes {@code sparseMax(ghostBatchNorm(linear(a)) * priors)}, where the linear layer
     * has {@code inputDim} units.
     */
    public static final class AttentionTransformer extends AbstractBlock {

        private static final byte VERSION = 1;

        private final Block fullyConnected;
        private final Block batchNorm;
        private final Block sparseMax;

        private AttentionTransformer(int inputDim, int virtualBatchSize, float batchNormMomentum) {
            super(VERSION);
            this.fullyConnected =
                    addChildBlock("fullyConnected", Linear.builder().setUnits(inputDim).build());
            this.batchNorm =
                    addChildBlock(
                            "ghostBatchNorm",
                            GhostBatchNorm.builder()
                                    .optVirtualBatchSize(virtualBatchSize)
                                    .optMomentum(batchNormMomentum)
                                    .build());
            this.sparseMax = addChildBlock("sparseMax", new SparseMax());
        }

        /** {@inheritDoc} */
        @Override
        protected NDList forwardInternal(
                ParameterStore parameterStore,
                NDList inputs,
                boolean training,
                PairList<String, Object> params) {
            NDArray x = inputs.get(0);
            NDArray priors = inputs.get(1);
            NDList x1 = fullyConnected.forward(parameterStore, new NDList(x), training);
            NDList x2 = batchNorm.forward(parameterStore, x1, training);
            return sparseMax.forward(
                    parameterStore, new NDList(x2.singletonOrThrow().mul(priors)), training);
        }

        /**
         * Propagates the first input shape through the children in registration order. The
         * priors shape is not needed, because the element-wise multiply keeps the shape.
         */
        @Override
        public Shape[] getOutputShapes(Shape[] inputShapes) {
            Shape[] shapes = {inputShapes[0]};
            for (Block child : getChildren().values()) {
                shapes = child.getOutputShapes(shapes);
            }
            return shapes;
        }

        /** {@inheritDoc} */
        @Override
        protected void initializeChildBlocks(
                NDManager manager, DataType dataType, Shape... inputShapes) {
            Shape[] shapes = {inputShapes[0]};
            for (Block child : getChildren().values()) {
                child.initialize(manager, dataType, shapes);
                shapes = child.getOutputShapes(shapes);
            }
        }
    }
}

