package org.nd4j.linalg.memory.provider;

import java.lang.ref.ReferenceQueue;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.MemoryWorkspaceManager;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.memory.pointers.PointersPair;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.memory.abstracts.DummyWorkspace;
import org.nd4j.linalg.memory.abstracts.Nd4jWorkspace;
import org.nd4j.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/memory/provider/BasicWorkspaceManager.class */
/**
 * Base {@link MemoryWorkspaceManager} that keeps a separate id-to-workspace registry for every
 * thread.
 *
 * <p>The registry lives in {@link #backingMap} (a {@link ThreadLocal}), so lookups and removals
 * only ever see the calling thread's workspaces. Workspaces passed to {@link #pickReference} are
 * additionally tracked through a {@link Nd4jWorkspace.GarbageWorkspaceReference} registered on a
 * private {@link ReferenceQueue}. A daemon {@link WorkspaceDeallocatorThread} drains that queue and
 * releases the host/device pointers still recorded on each enqueued reference.
 *
 * <p>{@code getWorkspaceForCurrentThread(WorkspaceConfiguration, String)} is not implemented here
 * and is left to subclasses.
 */
public abstract class BasicWorkspaceManager implements MemoryWorkspaceManager {
    private static final Logger log = LoggerFactory.getLogger(BasicWorkspaceManager.class);

    /** Id used by the overloads that do not take an explicit workspace id. */
    private static final String DEFAULT_WORKSPACE_ID = "DefaultWorkspace";

    /** Source of the numeric suffix produced by {@link #getUUID()}. */
    protected AtomicLong counter = new AtomicLong();

    /** Configuration used when a workspace is requested by id only. */
    protected WorkspaceConfiguration defaultConfiguration;

    /** Per-thread map of workspace id to workspace; created lazily by {@link #ensureThreadExistense()}. */
    protected ThreadLocal<Map<String, MemoryWorkspace>> backingMap = new ThreadLocal<>();

    /** Queue on which tracked workspace references are delivered to the deallocator thread. */
    private final ReferenceQueue<MemoryWorkspace> queue;

    /** Daemon thread draining {@link #queue}; started by the constructor. */
    private final WorkspaceDeallocatorThread thread;

    /**
     * Holds each tracked reference until the deallocator has processed it (a Reference object
     * that is itself unreachable may never be enqueued). Concurrent because it is written from
     * workspace-owning threads ({@link #pickReference}) and from the deallocator thread.
     */
    private final Map<String, Nd4jWorkspace.GarbageWorkspaceReference> referenceMap =
            new ConcurrentHashMap<>();

    /**
     * Daemon thread that blocks on the reference queue. For each enqueued
     * {@link Nd4jWorkspace.GarbageWorkspaceReference} it releases the native memory recorded on
     * the reference: the primary pair, every external pair, and every pinned pair.
     */
    protected class WorkspaceDeallocatorThread extends Thread {
        private final ReferenceQueue<MemoryWorkspace> queue;

        protected WorkspaceDeallocatorThread(ReferenceQueue<MemoryWorkspace> referenceQueue) {
            this.queue = referenceQueue;
            setDaemon(true);
            setName("Workspace deallocator thread");
        }

        /**
         * Processes enqueued references until the thread is interrupted. A failure while releasing
         * one reference is logged and does not stop the loop. The reference is always dropped from
         * {@code referenceMap}, so a failed release cannot leave a stale entry behind.
         */
        @Override
        public void run() {
            while (true) {
                Object enqueued;
                try {
                    enqueued = queue.remove();
                } catch (InterruptedException e) {
                    // Interruption is the only way out of this loop; restore the flag and exit.
                    Thread.currentThread().interrupt();
                    return;
                }
                if (!(enqueued instanceof Nd4jWorkspace.GarbageWorkspaceReference)) {
                    continue;
                }
                Nd4jWorkspace.GarbageWorkspaceReference reference =
                        (Nd4jWorkspace.GarbageWorkspaceReference) enqueued;
                try {
                    releaseWorkspaceMemory(reference);
                } catch (Exception e) {
                    log.warn("Failed to release memory of a collected workspace", e);
                } finally {
                    referenceMap.remove(reference.getKey());
                }
            }
        }

        /** Releases every pointer recorded on {@code reference}, skipping null pointers. */
        private void releaseWorkspaceMemory(Nd4jWorkspace.GarbageWorkspaceReference reference) {
            PointersPair primary = reference.getPointersPair();
            if (primary != null) {
                // The primary pair is released device-first; external/pinned pairs host-first.
                // This keeps the ordering of the original implementation.
                if (primary.getDevicePointer() != null) {
                    Nd4j.getMemoryManager().release(primary.getDevicePointer(), MemoryKind.DEVICE);
                }
                if (primary.getHostPointer() != null) {
                    Nd4j.getMemoryManager().release(primary.getHostPointer(), MemoryKind.HOST);
                }
            }
            for (PointersPair external : reference.getExternalPointers()) {
                releaseHostThenDevice(external);
            }
            PointersPair pinned;
            while ((pinned = reference.getPinnedPointers().poll()) != null) {
                releaseHostThenDevice(pinned);
            }
        }

        /** Releases the host pointer, then the device pointer, of {@code pair} (null-safe). */
        private void releaseHostThenDevice(PointersPair pair) {
            if (pair == null) {
                return;
            }
            if (pair.getHostPointer() != null) {
                Nd4j.getMemoryManager().release(pair.getHostPointer(), MemoryKind.HOST);
            }
            if (pair.getDevicePointer() != null) {
                Nd4j.getMemoryManager().release(pair.getDevicePointer(), MemoryKind.DEVICE);
            }
        }
    }

    /**
     * Creates a manager whose default configuration has zero initial and maximum size. The
     * configuration uses a 0.3 overallocation limit, OVERALLOCATE allocation, FIRST_LOOP learning,
     * FULL mirroring and EXTERNAL spill.
     */
    public BasicWorkspaceManager() {
        this(WorkspaceConfiguration.builder()
                .initialSize(0L)
                .maxSize(0L)
                .overallocationLimit(0.3d)
                .policyAllocation(AllocationPolicy.OVERALLOCATE)
                .policyLearning(LearningPolicy.FIRST_LOOP)
                .policyMirroring(MirroringPolicy.FULL)
                .policySpill(SpillPolicy.EXTERNAL)
                .build());
    }

    /**
     * Creates a manager with the given default configuration and starts its deallocator thread.
     *
     * @param workspaceConfiguration configuration used when a workspace is requested by id only
     * @throws NullPointerException if {@code workspaceConfiguration} is null
     */
    public BasicWorkspaceManager(@NonNull WorkspaceConfiguration workspaceConfiguration) {
        if (workspaceConfiguration == null) {
            throw new NullPointerException("defaultConfiguration");
        }
        this.defaultConfiguration = workspaceConfiguration;
        this.queue = new ReferenceQueue<>();
        this.thread = new WorkspaceDeallocatorThread(this.queue);
        this.thread.start();
    }

    /** Returns a new id of the form {@code Workspace_<n>}, where n increases on every call. */
    public String getUUID() {
        return "Workspace_" + counter.incrementAndGet();
    }

    /**
     * Replaces the configuration used for workspaces requested by id only.
     *
     * @throws NullPointerException if {@code workspaceConfiguration} is null
     */
    public void setDefaultWorkspaceConfiguration(@NonNull WorkspaceConfiguration workspaceConfiguration) {
        if (workspaceConfiguration == null) {
            throw new NullPointerException("configuration");
        }
        this.defaultConfiguration = workspaceConfiguration;
    }

    /** Returns the current thread's workspace with the default id and configuration. */
    public MemoryWorkspace getWorkspaceForCurrentThread() {
        return getWorkspaceForCurrentThread(DEFAULT_WORKSPACE_ID);
    }

    /**
     * Returns the current thread's workspace with id {@code str}, using the default configuration.
     *
     * @throws NullPointerException if {@code str} is null
     */
    public MemoryWorkspace getWorkspaceForCurrentThread(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("id");
        }
        return getWorkspaceForCurrentThread(defaultConfiguration, str);
    }

    /**
     * Starts tracking {@code memoryWorkspace}. Its native memory is released by the deallocator
     * thread once the tracking reference is enqueued.
     */
    protected void pickReference(MemoryWorkspace memoryWorkspace) {
        Nd4jWorkspace.GarbageWorkspaceReference reference =
                new Nd4jWorkspace.GarbageWorkspaceReference(memoryWorkspace, queue);
        referenceMap.put(reference.getKey(), reference);
    }

    /** Registers {@code memoryWorkspace} for the current thread under the default id. */
    public void setWorkspaceForCurrentThread(MemoryWorkspace memoryWorkspace) {
        setWorkspaceForCurrentThread(memoryWorkspace, DEFAULT_WORKSPACE_ID);
    }

    /**
     * Registers {@code memoryWorkspace} for the current thread under id {@code str}, replacing any
     * previous entry with that id.
     *
     * @throws NullPointerException if either argument is null
     */
    public void setWorkspaceForCurrentThread(@NonNull MemoryWorkspace memoryWorkspace, @NonNull String str) {
        if (memoryWorkspace == null) {
            throw new NullPointerException("workspace");
        }
        if (str == null) {
            throw new NullPointerException("id");
        }
        ensureThreadExistense();
        backingMap.get().put(str, memoryWorkspace);
    }

    /**
     * Removes {@code memoryWorkspace} (by its id) from the current thread's registry. Null and
     * {@link DummyWorkspace} arguments are ignored. This only unregisters the workspace; it frees
     * no memory itself.
     */
    public void destroyWorkspace(MemoryWorkspace memoryWorkspace) {
        if (memoryWorkspace == null || memoryWorkspace instanceof DummyWorkspace) {
            return;
        }
        // Guard like every other accessor: a thread that never registered anything has no map yet.
        ensureThreadExistense();
        backingMap.get().remove(memoryWorkspace.getId());
    }

    /** Removes the default-id workspace from the current thread's registry. */
    public void destroyWorkspace() {
        ensureThreadExistense();
        backingMap.get().remove(DEFAULT_WORKSPACE_ID);
    }

    /**
     * Unregisters every workspace of the current thread via {@link #destroyWorkspace(MemoryWorkspace)},
     * then calls {@link System#gc()}.
     */
    public void destroyAllWorkspacesForCurrentThread() {
        ensureThreadExistense();
        // Iterate over a snapshot: destroyWorkspace mutates the backing map.
        List<MemoryWorkspace> snapshot = new ArrayList<>(backingMap.get().values());
        for (MemoryWorkspace workspace : snapshot) {
            destroyWorkspace(workspace);
        }
        // Presumably a hint so dropped workspaces become collectable and reach the deallocator sooner.
        System.gc();
    }

    /** Lazily creates the current thread's registry map if it does not exist yet. */
    protected void ensureThreadExistense() {
        if (backingMap.get() == null) {
            backingMap.set(new HashMap<String, MemoryWorkspace>());
        }
    }

    /** Returns the default workspace of the current thread after calling {@code notifyScopeEntered()} on it. */
    public MemoryWorkspace getAndActivateWorkspace() {
        return getWorkspaceForCurrentThread().notifyScopeEntered();
    }

    /**
     * Returns the workspace {@code str} of the current thread after calling
     * {@code notifyScopeEntered()} on it.
     *
     * @throws NullPointerException if {@code str} is null
     */
    public MemoryWorkspace getAndActivateWorkspace(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("id");
        }
        return getWorkspaceForCurrentThread(str).notifyScopeEntered();
    }

    /**
     * Returns the workspace {@code str} (with the given configuration) of the current thread after
     * calling {@code notifyScopeEntered()} on it.
     *
     * @throws NullPointerException if either argument is null
     */
    public MemoryWorkspace getAndActivateWorkspace(@NonNull WorkspaceConfiguration workspaceConfiguration, @NonNull String str) {
        if (workspaceConfiguration == null) {
            throw new NullPointerException("configuration");
        }
        if (str == null) {
            throw new NullPointerException("id");
        }
        return getWorkspaceForCurrentThread(workspaceConfiguration, str).notifyScopeEntered();
    }

    /**
     * Returns whether the current thread has a workspace registered under {@code str}.
     *
     * @throws NullPointerException if {@code str} is null
     */
    public boolean checkIfWorkspaceExists(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("id");
        }
        ensureThreadExistense();
        return backingMap.get().containsKey(str);
    }

    /**
     * Returns whether the current thread has a workspace {@code str} whose scope is active.
     *
     * @throws NullPointerException if {@code str} is null
     */
    public boolean checkIfWorkspaceExistsAndActive(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("id");
        }
        return checkIfWorkspaceExists(str) && backingMap.get().get(str).isScopeActive();
    }

    /**
     * Detaches the memory manager's current workspace. If there is none, a new
     * {@link DummyWorkspace} is returned. Otherwise the current workspace is cleared and the result
     * of its {@code tagOutOfScopeUse()} is returned.
     */
    public MemoryWorkspace scopeOutOfWorkspaces() {
        MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace();
        if (current == null) {
            return new DummyWorkspace();
        }
        Nd4j.getMemoryManager().setCurrentWorkspace(null);
        return current.tagOutOfScopeUse();
    }

    /**
     * Logs, at INFO level, the allocated / spilled / pinned sizes of each workspace registered for
     * the current thread. Entries that are not {@link Nd4jWorkspace} instances have no size
     * accessors here; they are listed without figures instead of failing with a cast error.
     */
    public synchronized void printAllocationStatisticsForCurrentThread() {
        ensureThreadExistense();
        Map<String, MemoryWorkspace> map = backingMap.get();
        log.info("Workspace statistics: ---------------------------------");
        log.info("Number of workspaces in current thread: {}", map.size());
        log.info("Workspace name: Allocated / external (spilled) / external (pinned)");
        for (Map.Entry<String, MemoryWorkspace> entry : map.entrySet()) {
            String name = entry.getKey();
            MemoryWorkspace workspace = entry.getValue();
            if (!(workspace instanceof Nd4jWorkspace)) {
                log.info("{} no statistics available for {}", name + ":", workspace.getClass().getSimpleName());
                continue;
            }
            Nd4jWorkspace nd4jWorkspace = (Nd4jWorkspace) workspace;
            long currentSize = nd4jWorkspace.getCurrentSize();
            long spilledSize = nd4jWorkspace.getSpilledSize();
            long pinnedSize = nd4jWorkspace.getPinnedSize();
            log.info(String.format("%-26s %8s / %8s / %8s (%11d / %11d / %11d)",
                    name + ":",
                    StringUtils.TraditionalBinaryPrefix.long2String(currentSize, "", 2),
                    StringUtils.TraditionalBinaryPrefix.long2String(spilledSize, "", 2),
                    StringUtils.TraditionalBinaryPrefix.long2String(pinnedSize, "", 2),
                    currentSize, spilledSize, pinnedSize));
        }
    }

    /** Returns a snapshot copy of the workspace ids registered for the current thread. */
    public List<String> getAllWorkspacesIdsForCurrentThread() {
        ensureThreadExistense();
        return new ArrayList<>(backingMap.get().keySet());
    }

    /** Returns a snapshot copy of the workspaces registered for the current thread. */
    public List<MemoryWorkspace> getAllWorkspacesForCurrentThread() {
        ensureThreadExistense();
        return new ArrayList<>(backingMap.get().values());
    }

    /** Returns whether any workspace registered for the current thread has an active scope. */
    public boolean anyWorkspaceActiveForCurrentThread() {
        ensureThreadExistense();
        for (MemoryWorkspace workspace : backingMap.get().values()) {
            if (workspace.isScopeActive()) {
                return true;
            }
        }
        return false;
    }
}
