/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicmodelzoo.cv.object_detection.yolo;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.ParallelBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.pooling.Pool;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.ArrayList;
import java.util.Arrays;

/**
 * YOLOv3 detection network: a DarkNet-53 backbone followed by a three-scale detection head.
 *
 * <p>The backbone is a {@link SequentialBlock} that returns every intermediate output. The head
 * taps the 1024-, 512- and 256-channel backbone stages. Each deeper branch is reduced with a 1x1
 * convolution, up-sampled 2x (nearest neighbour), and concatenated along axis 1 with the next
 * shallower stage. Each of the three branches ends in a 1x1 convolution with
 * {@code 3 * (numClasses + 5)} filters.
 */
public final class YOLOV3 extends AbstractBlock {

    private static final byte VERSION = 1;

    /*
     * Indices into the backbone's intermediate outputs. The backbone's children are
     * conv, batch norm, activation, then five stages built by makeLayer(FILTERS[1..5], ...),
     * so children 5, 6 and 7 are the 256-, 512- and 1024-filter stages.
     */
    private static final int STAGE_256_INDEX = 5;
    private static final int STAGE_512_INDEX = 6;
    private static final int STAGE_1024_INDEX = 7;

    /** Number of residual blocks in each of the five backbone stages. */
    static final int[] REPEATS = new int[] {1, 2, 8, 8, 4};

    /** Filter counts: the stem convolution followed by the five backbone stages. */
    static final int[] FILTERS = new int[] {32, 64, 128, 256, 512, 1024};

    private final SequentialBlock darkNet53;
    private final Block lastLayer0;
    private final Block layer0Output;
    private final Block lastLayer1Conv;
    private final Block lastLayer1UpSample;
    private final Block lastLayer1;
    private final Block layer1Output;
    private final Block lastLayer2Conv;
    private final Block lastLayer2UpSample;
    private final Block lastLayer2;
    private final Block layer2Output;

    private YOLOV3(Builder builder) {
        super(VERSION);
        float momentum = builder.batchNormMomentum;
        float alpha = builder.leakyAlpha;
        // Presumably 3 anchors per scale, each predicting 4 box values, 1 objectness score and
        // numClasses class scores -- TODO confirm against the matching translator / loss.
        int outChannels = 3 * (builder.numClasses + 5);

        // Children are registered in the same order forwardInternal and propagateShapes use them.
        this.darkNet53 = addChildBlock("darkNet53", darkNet53(builder, true));
        this.lastLayer0 =
                addChildBlock("lastLayer0", makeLastLayers(FILTERS[4], FILTERS[5], momentum, alpha));
        this.layer0Output =
                addChildBlock("layer0Output", makeOutputLayers(FILTERS[5], outChannels, momentum, alpha));
        this.lastLayer1Conv =
                addChildBlock("lastLayer1Conv", convolutionBlock(FILTERS[3], 1, momentum, alpha));
        this.lastLayer1UpSample = addChildBlock("lastLayer1UpSample", upSampleBlockNearest());
        this.lastLayer1 =
                addChildBlock("lastLayer1", makeLastLayers(FILTERS[3], FILTERS[4], momentum, alpha));
        this.layer1Output =
                addChildBlock("layer1Output", makeOutputLayers(FILTERS[4], outChannels, momentum, alpha));
        this.lastLayer2Conv =
                addChildBlock("lastLayer2Conv", convolutionBlock(FILTERS[2], 1, momentum, alpha));
        this.lastLayer2UpSample = addChildBlock("lastLayer2UpSample", upSampleBlockNearest());
        this.lastLayer2 =
                addChildBlock("lastLayer2", makeLastLayers(FILTERS[2], FILTERS[3], momentum, alpha));
        this.layer2Output =
                addChildBlock("layer2Output", makeOutputLayers(FILTERS[3], outChannels, momentum, alpha));
    }

    /**
     * Creates a block that doubles the two spatial dimensions with nearest-neighbour resizing.
     *
     * <p>The input is transposed with axes (0, 2, 3, 1) so channels are last, resized, then
     * transposed back with (0, 3, 1, 2).
     *
     * @return the up-sampling block
     */
    public static Block upSampleBlockNearest() {
        return new SequentialBlock()
                .addSingleton(array -> array.transpose(0, 2, 3, 1))
                .addSingleton(
                        array -> {
                            Shape shape = array.getShape();
                            int height = (int) (shape.get(1) * 2L);
                            int width = (int) (shape.get(2) * 2L);
                            // NDImageUtils.resize takes (width, height); passing the height
                            // first would swap the axes of non-square feature maps.
                            return NDImageUtils.resize(
                                    array, width, height, Image.Interpolation.NEAREST);
                        })
                .addSingleton(array -> array.transpose(0, 3, 1, 2));
    }

    /**
     * Creates a convolution, batch-norm and leaky-ReLU block with stride 1.
     *
     * @param filters number of output filters
     * @param kernel square kernel size; padding is {@code (kernel - 1) / 2} for positive kernels
     * @param batchNormMomentum batch-norm momentum
     * @param leakyAlpha negative slope of the leaky ReLU
     * @return the block
     */
    public static Block convolutionBlock(
            int filters, int kernel, float batchNormMomentum, float leakyAlpha) {
        // Keeps the spatial size unchanged for odd kernels (1 -> pad 0, 3 -> pad 1).
        int pad = kernel > 0 ? (kernel - 1) >> 1 : 0;
        return new SequentialBlock()
                .add(
                        Conv2d.builder()
                                .setFilters(filters)
                                .setKernelShape(new Shape(kernel, kernel))
                                .optPadding(new Shape(pad, pad))
                                .build())
                .add(batchNorm(batchNormMomentum))
                .add(Activation.leakyReluBlock(leakyAlpha));
    }

    /**
     * Creates the five-convolution block that precedes each detection output.
     *
     * <p>It alternates 1x1 convolutions with {@code filtersIn} filters and 3x3 convolutions with
     * {@code filtersOut} filters, ending on a 1x1 convolution.
     *
     * @param filtersIn filters of the 1x1 convolutions (and of the block's output)
     * @param filtersOut filters of the 3x3 convolutions
     * @param batchNormMomentum batch-norm momentum
     * @param leakyAlpha negative slope of the leaky ReLU
     * @return the block
     */
    public static Block makeLastLayers(
            int filtersIn, int filtersOut, float batchNormMomentum, float leakyAlpha) {
        return new SequentialBlock()
                .add(convolutionBlock(filtersIn, 1, batchNormMomentum, leakyAlpha))
                .add(convolutionBlock(filtersOut, 3, batchNormMomentum, leakyAlpha))
                .add(convolutionBlock(filtersIn, 1, batchNormMomentum, leakyAlpha))
                .add(convolutionBlock(filtersOut, 3, batchNormMomentum, leakyAlpha))
                .add(convolutionBlock(filtersIn, 1, batchNormMomentum, leakyAlpha));
    }

    /**
     * Creates a detection output: a 3x3 convolution block followed by a plain 1x1 convolution.
     *
     * @param filtersOut filters of the 3x3 convolution block
     * @param outClass number of output channels of the final 1x1 convolution
     * @param batchNormMomentum batch-norm momentum
     * @param leakyAlpha negative slope of the leaky ReLU
     * @return the block
     */
    public static Block makeOutputLayers(
            int filtersOut, int outClass, float batchNormMomentum, float leakyAlpha) {
        return new SequentialBlock()
                .add(convolutionBlock(filtersOut, 3, batchNormMomentum, leakyAlpha))
                .add(Conv2d.builder().setFilters(outClass).setKernelShape(new Shape(1, 1)).build());
    }

    /**
     * Runs the backbone and the three detection branches.
     *
     * @return the three detection outputs, deepest (1024-channel branch) first
     */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDList features = darkNet53.forward(parameterStore, inputs, training);
        NDArray x0 = features.get(STAGE_1024_INDEX);
        NDArray x1 = features.get(STAGE_512_INDEX);
        NDArray x2 = features.get(STAGE_256_INDEX);

        NDList branch0 = lastLayer0.forward(parameterStore, new NDList(x0), training);
        NDList out0 = layer0Output.forward(parameterStore, branch0, training);

        NDList reduced1 = lastLayer1Conv.forward(parameterStore, branch0, training);
        NDList up1 = lastLayer1UpSample.forward(parameterStore, reduced1, training);
        NDList merged1 = new NDList(up1.singletonOrThrow().concat(x1, 1));
        NDList branch1 = lastLayer1.forward(parameterStore, merged1, training);
        NDList out1 = layer1Output.forward(parameterStore, branch1, training);

        NDList reduced2 = lastLayer2Conv.forward(parameterStore, branch1, training);
        NDList up2 = lastLayer2UpSample.forward(parameterStore, reduced2, training);
        NDList merged2 = new NDList(up2.singletonOrThrow().concat(x2, 1));
        NDList branch2 = lastLayer2.forward(parameterStore, merged2, training);
        NDList out2 = layer2Output.forward(parameterStore, branch2, training);

        return new NDList(out0.singletonOrThrow(), out1.singletonOrThrow(), out2.singletonOrThrow());
    }

    /** {@inheritDoc} */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        return propagateShapes(inputShapes, null, null);
    }

    /** {@inheritDoc} */
    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        propagateShapes(inputShapes, manager, dataType);
    }

    /**
     * Threads shapes through the children in exactly the order used by {@link #forwardInternal}.
     *
     * @param inputShapes the network input shapes
     * @param manager if non-null, each child is initialized with this manager before its output
     *     shape is queried; if null, only shapes are computed
     * @param dataType data type used for initialization (ignored when {@code manager} is null)
     * @return the shapes of the three detection outputs, deepest first
     */
    private Shape[] propagateShapes(Shape[] inputShapes, NDManager manager, DataType dataType) {
        Shape[] features = step(darkNet53, inputShapes, manager, dataType);

        Shape[] branch0 =
                step(lastLayer0, new Shape[] {features[STAGE_1024_INDEX]}, manager, dataType);
        Shape out0 = step(layer0Output, branch0, manager, dataType)[0];

        Shape[] reduced1 = step(lastLayer1Conv, branch0, manager, dataType);
        Shape[] up1 = step(lastLayer1UpSample, reduced1, manager, dataType);
        Shape[] merged1 = concatChannels(up1[0], features[STAGE_512_INDEX]);
        Shape[] branch1 = step(lastLayer1, merged1, manager, dataType);
        Shape out1 = step(layer1Output, branch1, manager, dataType)[0];

        Shape[] reduced2 = step(lastLayer2Conv, branch1, manager, dataType);
        Shape[] up2 = step(lastLayer2UpSample, reduced2, manager, dataType);
        Shape[] merged2 = concatChannels(up2[0], features[STAGE_256_INDEX]);
        Shape[] branch2 = step(lastLayer2, merged2, manager, dataType);
        Shape out2 = step(layer2Output, branch2, manager, dataType)[0];

        return new Shape[] {out0, out1, out2};
    }

    /** Optionally initializes {@code block}, then returns its output shapes for {@code input}. */
    private static Shape[] step(Block block, Shape[] input, NDManager manager, DataType dataType) {
        if (manager != null) {
            block.initialize(manager, dataType, input);
        }
        return block.getOutputShapes(input);
    }

    /** Shape of {@code upsampled} concatenated with {@code skip} along axis 1. */
    private static Shape[] concatChannels(Shape upsampled, Shape skip) {
        return new Shape[] {
            new Shape(
                    upsampled.get(0),
                    upsampled.get(1) + skip.get(1),
                    upsampled.get(2),
                    upsampled.get(3))
        };
    }

    /** Batch norm with the epsilon used throughout this network. */
    private static Block batchNorm(float momentum) {
        return BatchNorm.builder().optEpsilon(2.0E-5f).optMomentum(momentum).build();
    }

    /**
     * Creates a residual block: a 1x1 convolution with {@code filters / 2} filters and a padded
     * 3x3 convolution with {@code filters} filters, each followed by batch norm and leaky ReLU,
     * summed element-wise with the identity.
     *
     * @param filters output filters (the input is assumed to have the same count for the sum)
     * @param batchNormMomentum batch-norm momentum
     * @param leakyAlpha negative slope of the leaky ReLU
     * @return the block
     */
    public static Block basicBlock(int filters, float batchNormMomentum, float leakyAlpha) {
        SequentialBlock residual =
                new SequentialBlock()
                        .add(
                                Conv2d.builder()
                                        .setFilters(filters / 2)
                                        .setKernelShape(new Shape(1, 1))
                                        .build())
                        .add(batchNorm(batchNormMomentum))
                        .add(Activation.leakyReluBlock(leakyAlpha))
                        .add(
                                Conv2d.builder()
                                        .setFilters(filters)
                                        .setKernelShape(new Shape(3, 3))
                                        .optPadding(new Shape(1, 1))
                                        .build())
                        .add(batchNorm(batchNormMomentum))
                        .add(Activation.leakyReluBlock(leakyAlpha));
        return new ParallelBlock(
                list ->
                        new NDList(
                                NDArrays.add(
                                        list.get(0).singletonOrThrow(),
                                        list.get(1).singletonOrThrow())),
                Arrays.asList(residual, Blocks.identityBlock()));
    }

    /**
     * Creates a backbone stage: a stride-2 3x3 convolution (with batch norm and leaky ReLU)
     * followed by {@code repeats} residual blocks.
     *
     * @param filters filters of the stage
     * @param repeats number of residual blocks
     * @param batchNormMomentum batch-norm momentum
     * @param leakyAlpha negative slope of the leaky ReLU
     * @return the stage
     */
    public static Block makeLayer(int filters, int repeats, float batchNormMomentum, float leakyAlpha) {
        SequentialBlock downSample =
                new SequentialBlock()
                        .add(
                                Conv2d.builder()
                                        .setFilters(filters)
                                        .setKernelShape(new Shape(3, 3))
                                        .optStride(new Shape(2, 2))
                                        .optPadding(new Shape(1, 1))
                                        .build())
                        .add(batchNorm(batchNormMomentum))
                        .add(Activation.leakyReluBlock(leakyAlpha));
        ArrayList<Block> residuals = new ArrayList<>(Math.max(repeats, 0));
        for (int i = 0; i < repeats; ++i) {
            residuals.add(basicBlock(filters, batchNormMomentum, leakyAlpha));
        }
        return new SequentialBlock().add(downSample).addAll(residuals);
    }

    /**
     * Builds the DarkNet-53 backbone: a 3x3 stem convolution with batch norm and leaky ReLU,
     * then one {@link #makeLayer} stage per entry of {@link #REPEATS}.
     *
     * @param builder source of the batch-norm momentum and leaky-ReLU slope
     * @param setReturnIntermediate whether the sequential block returns every child's output
     * @return the backbone
     */
    private static SequentialBlock darkNet53(Builder builder, boolean setReturnIntermediate) {
        float momentum = builder.batchNormMomentum;
        float alpha = builder.leakyAlpha;
        SequentialBlock net = new SequentialBlock();
        net.setReturnIntermediate(setReturnIntermediate);
        net.add(
                        Conv2d.builder()
                                .setFilters(FILTERS[0])
                                .setKernelShape(new Shape(3, 3))
                                .optPadding(new Shape(1, 1))
                                .build())
                .add(batchNorm(momentum))
                .add(Activation.leakyReluBlock(alpha));
        // Stage i uses FILTERS[i + 1] filters and REPEATS[i] residual blocks.
        for (int i = 0; i < REPEATS.length; i++) {
            net.add(makeLayer(FILTERS[i + 1], REPEATS[i], momentum, alpha));
        }
        return net;
    }

    /**
     * Creates a builder for {@link YOLOV3}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Builder for {@link YOLOV3} and for a stand-alone DarkNet-53 classifier.
     *
     * <p>Defaults: 20 classes, batch-norm momentum 0.9, leaky-ReLU slope 0.1, and a DarkNet
     * classifier output size of 10.
     */
    public static final class Builder {
        int numClasses = 20;
        float batchNormMomentum = 0.9f;
        float leakyAlpha = 0.1f;
        int darkNetOutSize = 10;

        /**
         * Sets the number of object classes; each detection output gets
         * {@code 3 * (numClasses + 5)} channels.
         *
         * @param numClasses the number of classes
         * @return this builder
         */
        public Builder setNumClasses(int numClasses) {
            this.numClasses = numClasses;
            return this;
        }

        /**
         * Sets the momentum of every batch-norm layer.
         *
         * @param batchNormMomentum the momentum
         * @return this builder
         */
        public Builder optBatchNormMomentum(float batchNormMomentum) {
            this.batchNormMomentum = batchNormMomentum;
            return this;
        }

        /**
         * Sets the negative slope of every leaky ReLU.
         *
         * @param leakyAlpha the slope
         * @return this builder
         */
        public Builder optLeakyAlpha(float leakyAlpha) {
            this.leakyAlpha = leakyAlpha;
            return this;
        }

        /**
         * Sets the number of output units of the classifier made by {@link #buildDarkNet()}.
         *
         * @param darkNetOutSize the output size
         * @return this builder
         */
        public Builder optDarkNetOutSize(int darkNetOutSize) {
            this.darkNetOutSize = darkNetOutSize;
            return this;
        }

        /**
         * Builds the YOLOv3 detection network.
         *
         * @return the network
         */
        public Block build() {
            return new YOLOV3(this);
        }

        /**
         * Builds a DarkNet-53 classifier: the backbone (final output only), global average
         * pooling, and a linear layer with {@code darkNetOutSize} units.
         *
         * @return the classifier
         */
        public Block buildDarkNet() {
            SequentialBlock backbone = YOLOV3.darkNet53(this, false);
            return new SequentialBlock()
                    .add(backbone)
                    .add(Pool.globalAvgPool2dBlock())
                    .add(Linear.builder().setUnits(darkNetOutSize).build());
        }
    }
}

