package org.nd4j.linalg.workspace;

import java.lang.Enum;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/workspace/BaseWorkspaceMgr.class */
public abstract class BaseWorkspaceMgr<T extends Enum<T>> implements WorkspaceMgr<T> {
    private static final Logger log = LoggerFactory.getLogger(BaseWorkspaceMgr.class);

    /** Array types whose arrays are placed outside of any workspace ("scoped out"). */
    protected final Set<T> scopeOutOfWs;
    /** Workspace configuration registered for each array type. */
    protected final Map<T, WorkspaceConfiguration> configMap;
    /** Workspace name registered for each array type. */
    protected final Map<T, String> workspaceNames;

    /**
     * Creates a manager backed by the given collections. The collections are stored directly (not
     * copied), so mutations made through this manager are visible to whoever else holds them.
     *
     * @param scopeOutOfWs   array types that should be scoped out of workspaces
     * @param configMap      workspace configuration per array type
     * @param workspaceNames workspace name per array type
     */
    protected BaseWorkspaceMgr(Set<T> scopeOutOfWs, Map<T, WorkspaceConfiguration> configMap,
                               Map<T, String> workspaceNames) {
        this.scopeOutOfWs = scopeOutOfWs;
        this.configMap = configMap;
        this.workspaceNames = workspaceNames;
    }

    /** Creates a manager with empty, mutable (unsynchronized) backing collections. */
    protected BaseWorkspaceMgr() {
        this.scopeOutOfWs = new HashSet<>();
        this.configMap = new HashMap<>();
        this.workspaceNames = new HashMap<>();
    }

    /** Registers (or replaces) the workspace configuration for {@code arrayType}. */
    @Override
    public void setConfiguration(@NonNull T arrayType, WorkspaceConfiguration configuration) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        configMap.put(arrayType, configuration);
    }

    /** Returns the workspace configuration registered for {@code arrayType}, or {@code null} if none. */
    @Override
    public WorkspaceConfiguration getConfiguration(@NonNull T arrayType) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        return configMap.get(arrayType);
    }

    /**
     * Marks {@code arrayType} as scoped out of workspaces and forgets any configuration and
     * workspace name previously registered for it.
     */
    @Override
    public void setScopedOutFor(@NonNull T arrayType) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        scopeOutOfWs.add(arrayType);
        configMap.remove(arrayType);
        workspaceNames.remove(arrayType);
    }

    /** Returns {@code true} if {@code arrayType} has been marked as scoped out of workspaces. */
    @Override
    public boolean isScopedOut(@NonNull T arrayType) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        return scopeOutOfWs.contains(arrayType);
    }

    /**
     * Returns {@code true} if {@code arrayType} is scoped out or has a workspace name registered.
     * Note: this does not check {@link #configMap}; the scope/creation methods additionally require
     * a configuration and will throw {@link ND4JWorkspaceException} without one.
     */
    @Override
    public boolean hasConfiguration(@NonNull T arrayType) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        return scopeOutOfWs.contains(arrayType) || workspaceNames.containsKey(arrayType);
    }

    /**
     * Enters the workspace for {@code arrayType}, or scopes out of all workspaces if the type is
     * scoped out. The returned workspace should be closed to leave the scope.
     *
     * @throws ND4JWorkspaceException if the type is not scoped out and lacks a name or configuration
     */
    @Override
    public MemoryWorkspace notifyScopeEntered(@NonNull T arrayType) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        validateConfig(arrayType);
        if (isScopedOut(arrayType)) {
            return Nd4j.getWorkspaceManager().scopeOutOfWorkspaces();
        }
        return Nd4j.getWorkspaceManager()
                .getWorkspaceForCurrentThread(getConfiguration(arrayType), getWorkspaceName(arrayType))
                .notifyScopeEntered();
    }

    /**
     * Enters the workspaces for all given array types, in order, and returns them wrapped in a
     * single {@link WorkspacesCloseable}.
     *
     * <p>If entering any workspace fails, the workspaces already entered by this call are closed in
     * reverse order before the failure is rethrown, so no scope is left dangling.
     */
    @Override
    public WorkspacesCloseable notifyScopeEntered(@NonNull T... arrayTypes) {
        if (arrayTypes == null) {
            throw new NullPointerException("arrayTypes");
        }
        MemoryWorkspace[] workspaces = new MemoryWorkspace[arrayTypes.length];
        int entered = 0;
        try {
            for (; entered < arrayTypes.length; entered++) {
                // Resolves to the single-argument overload (non-varargs phase of overload resolution).
                workspaces[entered] = notifyScopeEntered(arrayTypes[entered]);
            }
        } catch (RuntimeException | Error e) {
            for (int i = entered - 1; i >= 0; i--) {
                if (workspaces[i] == null) {
                    continue;
                }
                try {
                    workspaces[i].close();
                } catch (RuntimeException | Error closeFailure) {
                    e.addSuppressed(closeFailure);
                }
            }
            throw e;
        }
        return new WorkspacesCloseable(workspaces);
    }

    /**
     * Borrows the (already open) workspace for {@code arrayType}, or scopes out of all workspaces
     * if the type is scoped out.
     *
     * @throws ND4JWorkspaceException if the configuration is incomplete or the workspace is not open
     */
    @Override
    public MemoryWorkspace notifyScopeBorrowed(@NonNull T arrayType) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        // enforceExistsAndActive also runs validateConfig.
        enforceExistsAndActive(arrayType);
        if (scopeOutOfWs.contains(arrayType)) {
            return Nd4j.getWorkspaceManager().scopeOutOfWorkspaces();
        }
        return Nd4j.getWorkspaceManager()
                .getWorkspaceForCurrentThread(getConfiguration(arrayType), getWorkspaceName(arrayType))
                .notifyScopeBorrowed();
    }

    /** Registers (or replaces) the workspace name for {@code arrayType}. */
    @Override
    public void setWorkspaceName(@NonNull T arrayType, @NonNull String name) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        if (name == null) {
            throw new NullPointerException("name");
        }
        workspaceNames.put(arrayType, name);
    }

    /** Returns the workspace name registered for {@code arrayType}, or {@code null} if none. */
    @Override
    public String getWorkspaceName(@NonNull T arrayType) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        return workspaceNames.get(arrayType);
    }

    /**
     * Assigns a workspace name and configuration to {@code forEnum}, clearing any previous
     * scoped-out marking for it.
     */
    @Override
    public void setWorkspace(@NonNull T forEnum, @NonNull String wsName,
                             @NonNull WorkspaceConfiguration configuration) {
        if (forEnum == null) {
            throw new NullPointerException("forEnum");
        }
        if (wsName == null) {
            throw new NullPointerException("wsName");
        }
        if (configuration == null) {
            throw new NullPointerException("configuration");
        }
        scopeOutOfWs.remove(forEnum);
        setWorkspaceName(forEnum, wsName);
        setConfiguration(forEnum, configuration);
    }

    /**
     * Returns {@code true} if the type is scoped out, otherwise whether the workspace manager
     * reports the named workspace as existing and active.
     *
     * @throws ND4JWorkspaceException if the type is not scoped out and lacks a name or configuration
     */
    @Override
    public boolean isWorkspaceOpen(@NonNull T arrayType) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        validateConfig(arrayType);
        return scopeOutOfWs.contains(arrayType)
                || Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(getWorkspaceName(arrayType));
    }

    /**
     * Throws if the workspace for {@code arrayType} is not open (scoped-out types always pass).
     *
     * @param msg extra text appended to the failure message (may be {@code null})
     */
    @Override
    public void assertOpen(T arrayType, String msg) throws ND4JWorkspaceException {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        if (!scopeOutOfWs.contains(arrayType) && !isWorkspaceOpen(arrayType)) {
            throw new ND4JWorkspaceException("Assertion failed: expected workspace for array type " + arrayType + " to be open: " + msg);
        }
    }

    /** Throws if the workspace for {@code arrayType} is open (scoped-out types always pass). */
    @Override
    public void assertNotOpen(@NonNull T arrayType, @NonNull String msg) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        if (msg == null) {
            throw new NullPointerException("msg");
        }
        if (!scopeOutOfWs.contains(arrayType) && isWorkspaceOpen(arrayType)) {
            throw new ND4JWorkspaceException("Assertion failed: expected workspace for array type " + arrayType + " to not be open: " + msg);
        }
    }

    /**
     * Throws unless the current thread's current workspace is the one registered for
     * {@code arrayType}. Scoped-out types always pass.
     *
     * @param msg optional extra text appended to the failure message
     */
    @Override
    public void assertCurrentWorkspace(@NonNull T arrayType, String msg) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        validateConfig(arrayType);
        if (scopeOutOfWs.contains(arrayType)) {
            return;
        }
        MemoryWorkspace currentWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        if (currentWorkspace == null || !getWorkspaceName(arrayType).equals(currentWorkspace.getId())) {
            throw new ND4JWorkspaceException("Assertion failed: expected current workspace to be \"" + getWorkspaceName(arrayType) + "\" (for array type " + arrayType + ") - actual current workspace is " + (currentWorkspace == null ? null : currentWorkspace.getId()) + (msg == null ? "" : ": " + msg));
        }
    }

    /**
     * Moves an attached array into the workspace for {@code arrayType}, or detaches it if the type
     * is scoped out. Arrays that are not attached to any workspace are returned unchanged.
     *
     * @throws ND4JWorkspaceException if the configuration is incomplete or the workspace is not open
     */
    @Override
    public INDArray leverageTo(@NonNull T arrayType, @NonNull INDArray array) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        if (array == null) {
            throw new NullPointerException("array");
        }
        if (!array.isAttached()) {
            return array;
        }
        // enforceExistsAndActive also runs validateConfig.
        enforceExistsAndActive(arrayType);
        if (scopeOutOfWs.contains(arrayType)) {
            return array.detach();
        }
        return array.leverageTo(getWorkspaceName(arrayType), true);
    }

    /**
     * Checks that {@code array} lives where arrays of {@code arrayType} should live: detached for
     * scoped-out types, otherwise in the registered workspace.
     *
     * @param migrateIfInvalid    if {@code true}, a misplaced attached array is moved via
     *                            {@link #leverageTo} instead of triggering an exception
     * @param exceptionIfDetached if {@code true}, a detached array for a non-scoped-out type is an error
     * @return the array, or its migrated copy
     * @throws ND4JWorkspaceException if validation fails and migration is not requested/applicable
     */
    @Override
    public INDArray validateArrayLocation(@NonNull T arrayType, @NonNull INDArray array,
                                          boolean migrateIfInvalid, boolean exceptionIfDetached) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        if (array == null) {
            throw new NullPointerException("array");
        }
        validateConfig(arrayType);

        if (scopeOutOfWs.contains(arrayType)) {
            if (!array.isAttached()) {
                return array;
            }
            if (migrateIfInvalid) {
                return leverageTo(arrayType, array);
            }
            throw new ND4JWorkspaceException("Array workspace validation failed: Array of type " + arrayType + " should be detached (no workspace) but is in workspace: " + array.data().getParentWorkspace().getId());
        }

        String workspaceName = getWorkspaceName(arrayType);
        if (!array.isAttached()) {
            if (exceptionIfDetached) {
                throw new ND4JWorkspaceException("Array workspace validation failed: Array of type " + arrayType + " should be in workspace \"" + workspaceName + "\" but is detached");
            }
            return array;
        }

        String actualId = array.data().getParentWorkspace().getId();
        if (workspaceName.equals(actualId)) {
            return array;
        }
        if (migrateIfInvalid) {
            return leverageTo(arrayType, array);
        }
        throw new ND4JWorkspaceException("Array workspace validation failed: Array of type " + arrayType + " should be in workspace \"" + workspaceName + "\" but is in workspace \"" + actualId + "\"");
    }

    /** Creates a zero-initialized array of the given shape, in Nd4j's default order, for {@code arrayType}. */
    @Override
    public INDArray create(@NonNull T arrayType, @NonNull int... shape) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        if (shape == null) {
            // Message text kept from the original (decompiler-inlined constant reference).
            throw new NullPointerException(TFGraphMapper.SHAPE_KEY);
        }
        return create(arrayType, shape, Nd4j.order().charValue());
    }

    /**
     * Creates an array of the given shape and order inside the workspace for {@code arrayType}
     * (or outside any workspace if the type is scoped out).
     *
     * @throws ND4JWorkspaceException if the configuration is incomplete or the workspace is not open
     */
    @Override
    public INDArray create(@NonNull T arrayType, @NonNull int[] shape, char order) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        if (shape == null) {
            throw new NullPointerException(TFGraphMapper.SHAPE_KEY);
        }
        // notifyScopeBorrowed validates the configuration and that the workspace is open;
        // try-with-resources guarantees exactly one close() and keeps the primary exception.
        try (MemoryWorkspace ignored = notifyScopeBorrowed(arrayType)) {
            return Nd4j.create(shape, order);
        }
    }

    /** Creates an uninitialized array of the given shape, in Nd4j's default order, for {@code arrayType}. */
    @Override
    public INDArray createUninitialized(@NonNull T arrayType, @NonNull int... shape) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        if (shape == null) {
            throw new NullPointerException(TFGraphMapper.SHAPE_KEY);
        }
        return createUninitialized(arrayType, shape, Nd4j.order().charValue());
    }

    /**
     * Creates an uninitialized array of the given shape and order inside the workspace for
     * {@code arrayType} (or outside any workspace if the type is scoped out).
     *
     * @throws ND4JWorkspaceException if the configuration is incomplete or the workspace is not open
     */
    @Override
    public INDArray createUninitialized(@NonNull T arrayType, @NonNull int[] shape, char order) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        if (shape == null) {
            throw new NullPointerException(TFGraphMapper.SHAPE_KEY);
        }
        try (MemoryWorkspace ignored = notifyScopeBorrowed(arrayType)) {
            return Nd4j.createUninitialized(shape, order);
        }
    }

    /**
     * Duplicates {@code toDup} with the given order into the workspace for {@code arrayType}
     * (or outside any workspace if the type is scoped out).
     *
     * @throws ND4JWorkspaceException if the configuration is incomplete or the workspace is not open
     */
    @Override
    public INDArray dup(@NonNull T arrayType, @NonNull INDArray toDup, char order) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        if (toDup == null) {
            throw new NullPointerException("toDup");
        }
        try (MemoryWorkspace ignored = notifyScopeBorrowed(arrayType)) {
            return toDup.dup(order);
        }
    }

    /** Duplicates {@code toDup}, keeping its ordering, into the workspace for {@code arrayType}. */
    @Override
    public INDArray dup(@NonNull T arrayType, @NonNull INDArray toDup) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        if (toDup == null) {
            throw new NullPointerException("toDup");
        }
        return dup(arrayType, toDup, toDup.ordering());
    }

    /**
     * Ensures a non-scoped-out {@code arrayType} has both a configuration and a workspace name.
     *
     * @throws ND4JWorkspaceException if either is missing
     */
    private void validateConfig(@NonNull T arrayType) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        if (scopeOutOfWs.contains(arrayType)) {
            return;
        }
        if (!configMap.containsKey(arrayType)) {
            throw new ND4JWorkspaceException("No workspace configuration has been provided for arrayType: " + arrayType);
        }
        if (!workspaceNames.containsKey(arrayType)) {
            throw new ND4JWorkspaceException("No workspace name has been provided for arrayType: " + arrayType);
        }
    }

    /**
     * Runs {@link #validateConfig} and, for non-scoped-out types, checks that the workspace
     * manager reports the named workspace as existing and active.
     *
     * @throws ND4JWorkspaceException if validation fails or the workspace is not open
     */
    private void enforceExistsAndActive(@NonNull T arrayType) {
        if (arrayType == null) {
            throw new NullPointerException("arrayType");
        }
        validateConfig(arrayType);
        if (scopeOutOfWs.contains(arrayType)) {
            return;
        }
        String name = workspaceNames.get(arrayType);
        if (!Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(name)) {
            throw new ND4JWorkspaceException("Workspace \"" + name + "\" for array type " + arrayType + " is not open");
        }
    }
}
