package org.nd4j.linalg.api.ndarray;

import com.google.common.primitives.Ints;
import java.util.ArrayList;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndexAll;
import org.nd4j.linalg.indexing.PointIndex;
import org.nd4j.linalg.indexing.ShapeOffsetResolution;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import org.nd4j.linalg.util.LongUtils;

/* loaded from: input_file:org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.class */
/**
 * Base class for sparse arrays stored in compressed sparse row (CSR) form.
 *
 * <p>The stored entries are described by four buffers:
 * <ul>
 *   <li>{@link #values} - the stored values;</li>
 *   <li>{@link #columnsPointers} - the column index of each stored value;</li>
 *   <li>{@link #pointerB} - for every row, the position of its first stored entry;</li>
 *   <li>{@link #pointerE} - for every row, the position one past its last stored entry.</li>
 * </ul>
 * A row {@code r} therefore owns the positions {@code [pointerB[r], pointerE[r])}, which is how
 * {@link #toDense()} and {@link #putScalar(int, int, double)} walk the data.
 */
public abstract class BaseSparseNDArrayCSR extends BaseSparseNDArray {
    /** Storage format reported by {@link #getFormat()}. */
    protected static final SparseFormat format = SparseFormat.CSR;

    /** Factor applied to the number of initial entries when sizing the value/column buffers. */
    private static final double INITIAL_CAPACITY_FACTOR = 2.0d;

    /** Stored values. */
    protected transient volatile DataBuffer values;
    /** Column index of each stored value. */
    protected transient volatile DataBuffer columnsPointers;
    /** Per-row position of the first stored entry. */
    protected transient volatile DataBuffer pointerB;
    /** Per-row position one past the last stored entry. */
    protected transient volatile DataBuffer pointerE;

    /**
     * Creates a CSR array from double values.
     *
     * <p>The value and column buffers are allocated with twice the number of supplied entries so
     * that a few insertions can happen before a reallocation is needed.
     *
     * @param dArr  stored values; must have the same length as {@code iArr}
     * @param iArr  column index of each stored value
     * @param iArr2 per-row begin positions; must have the same length as {@code iArr3}
     * @param iArr3 per-row end positions
     * @param iArr4 shape of the array
     */
    public BaseSparseNDArrayCSR(double[] dArr, int[] iArr, int[] iArr2, int[] iArr3, int[] iArr4) {
        Preconditions.checkArgument(dArr.length == iArr.length);
        Preconditions.checkArgument(iArr2.length == iArr3.length);
        setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(iArr4));
        init(iArr4);

        int capacity = (int) (dArr.length * INITIAL_CAPACITY_FACTOR);
        this.values = Nd4j.getDataBufferFactory().createDouble(capacity);
        this.values.setData(dArr);
        this.columnsPointers = Nd4j.getDataBufferFactory().createInt(capacity);
        this.columnsPointers.setData(iArr);
        this.length = iArr.length;

        initRowPointers(iArr2, iArr3);
    }

    /**
     * Creates a CSR array from float values by wrapping them in a buffer and delegating to
     * {@link #BaseSparseNDArrayCSR(DataBuffer, int[], int[], int[], int[])}.
     *
     * @param fArr  stored values
     * @param iArr  column index of each stored value
     * @param iArr2 per-row begin positions
     * @param iArr3 per-row end positions
     * @param iArr4 shape of the array
     */
    public BaseSparseNDArrayCSR(float[] fArr, int[] iArr, int[] iArr2, int[] iArr3, int[] iArr4) {
        this(Nd4j.createBuffer(fArr), iArr, iArr2, iArr3, iArr4);
    }

    /**
     * Creates a CSR array on top of an existing value buffer.
     *
     * <p>The given buffer is used directly (not copied). The column buffer is allocated with the
     * same length as the value buffer, and the number of stored entries is taken from
     * {@code iArr.length}.
     *
     * @param dataBuffer buffer holding the stored values
     * @param iArr       column index of each stored value
     * @param iArr2      per-row begin positions; must have the same length as {@code iArr3}
     * @param iArr3      per-row end positions
     * @param iArr4      shape of the array
     */
    public BaseSparseNDArrayCSR(DataBuffer dataBuffer, int[] iArr, int[] iArr2, int[] iArr3, int[] iArr4) {
        Preconditions.checkArgument(iArr2.length == iArr3.length);
        setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(iArr4));
        init(iArr4);

        this.values = dataBuffer;
        this.columnsPointers = Nd4j.getDataBufferFactory().createInt(dataBuffer.length());
        this.columnsPointers.setData(iArr);
        this.length = iArr.length;

        initRowPointers(iArr2, iArr3);
    }

    /** Allocates the per-row begin/end buffers (one slot per row) and fills them. */
    private void initRowPointers(int[] begins, int[] ends) {
        int pointerCount = this.rows;
        this.pointerB = Nd4j.getDataBufferFactory().createInt(pointerCount);
        this.pointerB.setData(begins);
        this.pointerE = Nd4j.getDataBufferFactory().createInt(pointerCount);
        this.pointerE.setData(ends);
    }

    /**
     * Sets the value at the given row and column.
     *
     * <p>If the cell is already stored its value is overwritten in place. Otherwise a new entry is
     * inserted at its sorted position inside the row, and the begin/end pointers of this row and
     * all following rows are shifted by one.
     *
     * @param i  row index, {@code 0 <= i < rows}
     * @param i2 column index, {@code 0 <= i2 < columns}
     * @param d  value to store
     * @return this array
     */
    @Override
    public INDArray putScalar(int i, int i2, double d) {
        // Validate the indices themselves (not the dimensions) so negative values are rejected.
        Preconditions.checkArgument(0 <= i && i < this.rows);
        Preconditions.checkArgument(0 <= i2 && i2 < this.columns);

        int rowEnd = this.pointerE.getInt(i);
        int position = this.pointerB.getInt(i);

        // Walk the row's entries only: stop at the row end or at the first column >= i2.
        while (position < rowEnd && this.columnsPointers.getInt(position) < i2) {
            position++;
        }

        // A match only counts if it lies inside this row; at rowEnd the slot belongs to the
        // next row (or to unused capacity) and must not be overwritten.
        if (position < rowEnd && this.columnsPointers.getInt(position) == i2) {
            this.values.put(position, d);
            return this;
        }

        // Insert a new entry in both buffers at the sorted position.
        this.values = addAtPosition(this.values, this.length, position, d);
        this.columnsPointers = addAtPosition(this.columnsPointers, this.length, position, i2);
        this.length++;

        // This row grew by one entry; every following row moved one position to the right.
        this.pointerE.put(i, this.pointerE.getInt(i) + 1);
        for (int row = i + 1; row < this.rows; row++) {
            this.pointerB.put(row, this.pointerB.getInt(row) + 1);
            this.pointerE.put(row, this.pointerE.getInt(row) + 1);
        }
        return this;
    }

    /**
     * Returns the part of this array selected by the given indexes.
     *
     * <p>Selections that cover the whole array (a single "all" index, or point-0 plus "all" on a
     * row/column vector) return this instance. Otherwise the indexes are resolved and the result
     * is built by {@link #subArray(ShapeOffsetResolution)}.
     *
     * @param iNDArrayIndexArr the indexes to apply
     * @return the selected array, or {@code null} when the resolved indexes contain a
     *         {@link SpecifiedIndex} (that case is not implemented)
     * @throws IllegalStateException if the resolved index array is empty
     */
    @Override
    public INDArray get(INDArrayIndex... iNDArrayIndexArr) {
        boolean selectsEverything =
                iNDArrayIndexArr.length == 1 && iNDArrayIndexArr[0] instanceof NDArrayIndexAll;
        if (selectsEverything || (iNDArrayIndexArr.length == 2 && selectsWholeVector(iNDArrayIndexArr))) {
            return this;
        }

        INDArrayIndex[] resolved = NDArrayIndex.resolve(shapeInfoDataBuffer(), iNDArrayIndexArr);
        ShapeOffsetResolution resolution = new ShapeOffsetResolution(this);
        resolution.exec(resolved);
        if (resolved.length < 1) {
            throw new IllegalStateException("Invalid index found of zero length");
        }

        int[] resolvedShape = LongUtils.toInts(resolution.getShapes());
        int specifiedCount = 0;
        for (INDArrayIndex index : resolved) {
            if (index instanceof SpecifiedIndex) {
                specifiedCount++;
            }
        }

        if (resolvedShape != null && specifiedCount > 0) {
            // Building a result from specified indexes is not implemented; callers get null.
            return null;
        }
        return subArray(resolution);
    }

    /**
     * Tells whether a two-element index selects the entire row or column vector, i.e.
     * {@code (point 0, all)} on a row vector or {@code (all, point 0)} on a column vector.
     */
    private boolean selectsWholeVector(INDArrayIndex[] indexes) {
        INDArrayIndex rowIndex = indexes[0];
        INDArrayIndex columnIndex = indexes[1];
        if (isRowVector()
                && rowIndex instanceof PointIndex
                && rowIndex.offset() == 0
                && columnIndex instanceof NDArrayIndexAll) {
            return true;
        }
        // Mirror of the row case: the offset that matters is the one of the point index.
        return isColumnVector()
                && columnIndex instanceof PointIndex
                && columnIndex.offset() == 0
                && rowIndex instanceof NDArrayIndexAll;
    }

    /**
     * Returns a buffer created from the column-index buffer, starting at offset 0 and spanning
     * {@code length()} elements.
     */
    @Override
    public DataBuffer getVectorCoordinates() {
        return Nd4j.getDataBufferFactory().create(this.columnsPointers, 0L, length());
    }

    /** Returns the first {@code length} stored values as doubles. */
    public double[] getDoubleValues() {
        return this.values.getDoublesAt(0L, (int) this.length);
    }

    /** Returns the first {@code length} column indexes, converted to doubles. */
    public double[] getColumns() {
        return this.columnsPointers.getDoublesAt(0L, (int) this.length);
    }

    /** Returns the per-row begin positions as an int array. */
    public int[] getPointerBArray() {
        return this.pointerB.asInt();
    }

    /** Returns the per-row end positions as an int array. */
    public int[] getPointerEArray() {
        return this.pointerE.asInt();
    }

    /** Returns {@link SparseFormat#CSR}. */
    @Override
    public SparseFormat getFormat() {
        return format;
    }

    /**
     * Returns a buffer created from the row-begin buffer, starting at offset 0 and spanning
     * {@code rows()} elements.
     */
    public DataBuffer getPointerB() {
        return Nd4j.getDataBufferFactory().create(this.pointerB, 0L, rows());
    }

    /**
     * Returns a buffer created from the row-end buffer, starting at offset 0 and spanning
     * {@code rows()} elements.
     */
    public DataBuffer getPointerE() {
        return Nd4j.getDataBufferFactory().create(this.pointerE, 0L, rows());
    }

    /**
     * Inserts {@code d} at position {@code i} of the buffer, moving the elements in
     * {@code [i, j)} one slot to the right.
     *
     * @param dataBuffer buffer to insert into
     * @param j          number of elements currently in use
     * @param i          insertion position
     * @param d          value to insert
     * @return the buffer holding the result; a reallocated buffer when the original was full
     */
    private DataBuffer addAtPosition(DataBuffer dataBuffer, long j, int i, double d) {
        // Grow first when every slot is already in use.
        DataBuffer target = dataBuffer.length() == j ? reallocate(dataBuffer) : dataBuffer;
        // Snapshot the tail before overwriting its first slot.
        double[] tail = target.getDoublesAt(i, ((int) j) - i);
        target.put(i, d);
        for (int k = 0; k < tail.length; k++) {
            target.put(i + 1 + k, tail[k]);
        }
        return target;
    }

    /**
     * Returns a buffer created from the value buffer, starting at offset 0 and spanning
     * {@code length()} elements.
     */
    @Override
    public DataBuffer data() {
        return Nd4j.getDataBufferFactory().create(this.values, 0L, length());
    }

    /**
     * Builds a dense array of the same shape, initialised to zeros, and copies every stored entry
     * into it.
     */
    @Override
    public INDArray toDense() {
        INDArray dense = Nd4j.zeros(shape());
        int[] rowBegin = this.pointerB.asInt();
        int[] rowEnd = this.pointerE.asInt();
        for (int row = 0; row < rows(); row++) {
            for (int k = rowBegin[row]; k < rowEnd[row]; k++) {
                dense.put(row, this.columnsPointers.getInt(k), this.values.getNumber(k));
            }
        }
        return dense;
    }

    /** Returns the shape information buffer of this array. */
    @Override
    public DataBuffer shapeInfoDataBuffer() {
        return this.shapeInformation;
    }

    /**
     * Builds the two-dimensional sub-array described by the given resolution.
     *
     * <p>Rows {@code [offsets[0], offsets[0] + shape[0])} and columns
     * {@code [offsets[1], offsets[1] + shape[1])} are selected. The result is created with this
     * array's value buffer, the selected column indexes (rebased to the first selected column)
     * and new begin/end pointers. When the resolution carries a non-zero linear offset, the row
     * and column offsets are derived from it and written into the resolution's offsets array.
     *
     * @param shapeOffsetResolution resolved offsets and shape of the selection
     * @return the sub-array created through {@code Nd4j.createSparseCSR}
     * @throws UnsupportedOperationException if the resolved shape is not two-dimensional
     */
    @Override
    public INDArray subArray(ShapeOffsetResolution shapeOffsetResolution) {
        long[] offsets = shapeOffsetResolution.getOffsets();
        int[] subShape = LongUtils.toInts(shapeOffsetResolution.getShapes());
        ArrayList<Integer> subColumns = new ArrayList<Integer>();
        ArrayList<Integer> subPointerB = new ArrayList<Integer>();
        ArrayList<Integer> subPointerE = new ArrayList<Integer>();

        if (subShape.length != 2) {
            throw new UnsupportedOperationException();
        }

        if (shapeOffsetResolution.getOffset() != 0) {
            // Split the linear offset into a row and a column offset.
            offsets[0] = ((int) shapeOffsetResolution.getOffset()) / shape()[1];
            offsets[1] = ((int) shapeOffsetResolution.getOffset()) % shape()[1];
        }

        long firstRow = offsets[0];
        long lastRow = firstRow + subShape[0];
        long firstColumn = offsets[1];
        long lastColumn = firstColumn + subShape[1];

        // NOTE(review): columns are read with a running entry counter rather than with the
        // pointer position; the two only coincide when entries are packed from position 0.
        int entryCount = 0;
        int selectedRows = 0;
        for (int rowIdx = 0; rowIdx < lastRow; rowIdx++) {
            boolean rowHasNoSelection = true;
            for (int idx = this.pointerB.getInt(rowIdx); idx < this.pointerE.getInt(rowIdx); idx++) {
                int colIdx = this.columnsPointers.getInt(entryCount);
                if (colIdx >= firstColumn && colIdx < lastColumn && rowIdx >= firstRow && rowIdx < lastRow) {
                    // Column index relative to the first selected column.
                    subColumns.add((int) (colIdx - firstColumn));
                    if (rowHasNoSelection) {
                        // First selected entry of the row opens its [begin, end) range.
                        subPointerB.add(idx);
                        subPointerE.add(idx + 1);
                        rowHasNoSelection = false;
                    } else {
                        // Further entries only extend the end of the range.
                        subPointerE.set((int) (rowIdx - firstRow), idx + 1);
                    }
                }
                entryCount++;
            }

            // A selected row without any selected entry gets an empty range placed at the end
            // of the previous selected row.
            if (rowHasNoSelection && rowIdx >= firstRow && rowIdx < lastRow) {
                int previousEnd = selectedRows == 0 ? 0 : subPointerE.get(selectedRows - 1);
                subPointerB.add(previousEnd);
                subPointerE.add(previousEnd);
            }
            if (rowIdx >= firstRow && rowIdx <= lastRow) {
                selectedRows++;
            }
        }

        return Nd4j.createSparseCSR(
                this.values,
                Ints.toArray(subColumns),
                Ints.toArray(subPointerB),
                Ints.toArray(subPointerE),
                subShape);
    }

    /**
     * Not supported for CSR arrays.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray subArray(long[] jArr, int[] iArr, int[] iArr2) {
        throw new UnsupportedOperationException();
    }

    /**
     * Compares this array with another object.
     *
     * <p>Anything that is not an {@link INDArray} is unequal. A non-sparse array is compared
     * against the dense form of this array. Another CSR array is equal when the row and column
     * counts and all four buffers (columns, values, row begins, row ends) are equal. Any other
     * sparse format is compared through the dense forms of both arrays.
     */
    @Override
    public boolean equals(Object obj) {
        // instanceof is false for null, so this also covers the null case.
        if (!(obj instanceof INDArray)) {
            return false;
        }
        INDArray otherArray = (INDArray) obj;
        if (!otherArray.isSparse()) {
            return toDense().equals(obj);
        }

        BaseSparseNDArray otherSparse = (BaseSparseNDArray) otherArray;
        switch (otherSparse.getFormat()) {
            case CSR: {
                BaseSparseNDArrayCSR otherCsr = (BaseSparseNDArrayCSR) otherSparse;
                return otherCsr.rows() == rows()
                        && otherCsr.columns() == columns()
                        && otherCsr.getVectorCoordinates().equals(getVectorCoordinates())
                        && otherCsr.data().equals(data())
                        && otherCsr.getPointerB().equals(getPointerB())
                        && otherCsr.getPointerE().equals(getPointerE());
            }
            default:
                return toDense().equals(otherSparse.toDense());
        }
    }

    /** Always {@code false}. */
    @Override
    public boolean isView() {
        return false;
    }

    /** Returns the rank of this array. */
    @Override
    public int underlyingRank() {
        return this.rank;
    }

    /** Not implemented: performs no work and returns {@code null}. */
    @Override
    public INDArray putiColumnVector(INDArray iNDArray) {
        return null;
    }

    /** Not implemented: performs no work and returns {@code null}. */
    @Override
    public INDArray putiRowVector(INDArray iNDArray) {
        return null;
    }
}
