/*
 * Decompiled with CFR 0.152.
 */
package opennlp.tools.ml.maxent.quasinewton;

import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import opennlp.tools.ml.ArrayMath;
import opennlp.tools.ml.maxent.quasinewton.NegLogLikelihood;
import opennlp.tools.ml.model.DataIndexer;

/**
 * A {@link NegLogLikelihood} whose value and gradient are computed by splitting the
 * training contexts across a fixed number of worker threads.
 *
 * <p>Each worker writes a partial sum into its own slot of {@code negLogLikelihoodThread}
 * (value) or {@code gradientThread} (gradient). The calling thread adds the partial sums
 * together after every worker has finished, so no locking is needed between workers.
 *
 * <p>Instances are not designed for concurrent use: every call to {@link #valueAt} or
 * {@link #gradientAt} overwrites the same per-thread buffers and the inherited
 * {@code gradient} array.
 */
public class ParallelNegLogLikelihood
extends NegLogLikelihood {
    /** Number of worker threads (and of partial-sum slots). Always at least 1. */
    private final int threads;
    /** Partial negative log-likelihood computed by each worker, indexed by worker. */
    private final double[] negLogLikelihoodThread;
    /** Partial gradient computed by each worker, indexed by worker then parameter. */
    private final double[][] gradientThread;

    /**
     * Creates a parallel negative log-likelihood function.
     *
     * @param indexer the data indexer passed to the {@link NegLogLikelihood} constructor
     * @param threads number of worker threads; must be at least 1
     * @throws IllegalArgumentException if {@code threads} is less than 1
     */
    public ParallelNegLogLikelihood(DataIndexer indexer, int threads) {
        super(indexer);
        if (threads <= 0) {
            throw new IllegalArgumentException("Number of threads must be 1 or larger");
        }
        this.threads = threads;
        this.negLogLikelihoodThread = new double[threads];
        // `dimension` is read after super(indexer) has run; presumably it is set there.
        this.gradientThread = new double[threads][this.dimension];
    }

    /**
     * Computes the negative log-likelihood at {@code x} using all worker threads.
     *
     * @param x parameter vector; its length must equal {@code dimension}
     * @return the sum of the per-worker partial negative log-likelihoods
     * @throws IllegalArgumentException if {@code x} has the wrong length
     * @throws IllegalStateException if a worker fails or the caller is interrupted
     */
    @Override
    public double valueAt(double[] x) {
        if (x.length != this.dimension) {
            throw new IllegalArgumentException("x is invalid, its dimension is not equal to domain dimension.");
        }
        this.computeInParallel(x, NegLLComputeTask::new);
        double negLogLikelihood = 0.0;
        for (double partial : this.negLogLikelihoodThread) {
            negLogLikelihood += partial;
        }
        return negLogLikelihood;
    }

    /**
     * Computes the gradient of the negative log-likelihood at {@code x} using all worker
     * threads.
     *
     * @param x parameter vector; its length must equal {@code dimension}
     * @return the inherited {@code gradient} array, overwritten with the sum of the
     *     per-worker partial gradients (the same array is reused on every call)
     * @throws IllegalArgumentException if {@code x} has the wrong length
     * @throws IllegalStateException if a worker fails or the caller is interrupted
     */
    @Override
    public double[] gradientAt(double[] x) {
        if (x.length != this.dimension) {
            throw new IllegalArgumentException("x is invalid, its dimension is not equal to the function.");
        }
        this.computeInParallel(x, GradientComputeTask::new);
        for (int i = 0; i < this.dimension; ++i) {
            double sum = 0.0;
            for (int t = 0; t < this.threads; ++t) {
                sum += this.gradientThread[t][i];
            }
            this.gradient[i] = sum;
        }
        return this.gradient;
    }

    /** Creates the task run by one worker; implemented by the task constructors. */
    @FunctionalInterface
    private interface TaskFactory {
        ComputeTask create(int threadIndex, int startIndex, int length, double[] x);
    }

    /**
     * Splits the contexts into {@code threads} contiguous ranges and runs one task per
     * range on a freshly created pool. It returns only after every task has completed.
     * The last range also takes the {@code numContexts % threads} leftover contexts.
     *
     * @param x parameter vector shared (read-only) by all tasks
     * @param factory creates the task for each range
     * @throws IllegalStateException if any task throws, or if the calling thread is
     *     interrupted while waiting (its interrupt status is restored in that case)
     */
    private void computeInParallel(double[] x, TaskFactory factory) {
        ExecutorService executor = Executors.newFixedThreadPool(this.threads, runnable -> {
            Thread thread = new Thread(runnable);
            thread.setName("opennlp.tools.ml.maxent.quasinewton.ParallelNegLogLikelihood.computeInParallel()");
            thread.setDaemon(true);
            return thread;
        });
        try {
            int taskSize = this.numContexts / this.threads;
            int leftOver = this.numContexts % this.threads;
            List<Future<ComputeTask>> futures = new ArrayList<>(this.threads);
            for (int i = 0; i < this.threads; ++i) {
                int length = (i == this.threads - 1) ? taskSize + leftOver : taskSize;
                futures.add(executor.submit(factory.create(i, i * taskSize, length, x)));
            }
            // Future.get() also makes each worker's writes to its partial-sum slot
            // visible to this thread before the slots are summed.
            for (Future<ComputeTask> future : futures) {
                future.get();
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while waiting for parallel computation", e);
        }
        catch (ExecutionException e) {
            // Propagate instead of returning stale partial sums from a previous call.
            throw new IllegalStateException("Parallel computation task failed", e.getCause());
        }
        finally {
            executor.shutdownNow();
        }
    }

    /** Computes one worker's share of the gradient into {@code gradientThread[threadIndex]}. */
    class GradientComputeTask
    extends ComputeTask {
        /** Per-outcome scratch buffer: scores first, then normalised probabilities. */
        final double[] expectation;

        public GradientComputeTask(int threadIndex, int startIndex, int length, double[] x) {
            super(threadIndex, startIndex, length, x);
            this.expectation = new double[ParallelNegLogLikelihood.this.numOutcomes];
        }

        @Override
        public GradientComputeTask call() {
            final double[] partialGradient = ParallelNegLogLikelihood.this.gradientThread[this.threadIndex];
            Arrays.fill(partialGradient, 0.0);
            for (int ci = this.startIndex; ci < this.startIndex + this.length; ++ci) {
                // Linear score of every outcome for context ci.
                for (int oi = 0; oi < numOutcomes; ++oi) {
                    this.expectation[oi] = 0.0;
                    for (int ai = 0; ai < contexts[ci].length; ++ai) {
                        int vectorIndex = indexOf(oi, contexts[ci][ai]);
                        double predValue = values != null ? values[ci][ai] : 1.0;
                        this.expectation[oi] += predValue * this.x[vectorIndex];
                    }
                }
                // Normalise the scores via log-sum-exp so that exp() does not overflow.
                double logSumOfExps = ArrayMath.logSumOfExps(this.expectation);
                for (int oi = 0; oi < numOutcomes; ++oi) {
                    this.expectation[oi] = StrictMath.exp(this.expectation[oi] - logSumOfExps);
                }
                // Accumulate (model probability - observed indicator) * feature value * count.
                for (int oi = 0; oi < numOutcomes; ++oi) {
                    double empirical = outcomeList[ci] == oi ? 1.0 : 0.0;
                    for (int ai = 0; ai < contexts[ci].length; ++ai) {
                        int vectorIndex = indexOf(oi, contexts[ci][ai]);
                        double predValue = values != null ? values[ci][ai] : 1.0;
                        partialGradient[vectorIndex] += predValue * (this.expectation[oi] - empirical) * numTimesEventsSeen[ci];
                    }
                }
            }
            return this;
        }
    }

    /** Computes one worker's share of the negative log-likelihood. */
    class NegLLComputeTask
    extends ComputeTask {
        /** Per-outcome scratch buffer holding the linear score of each outcome. */
        final double[] tempSums;

        public NegLLComputeTask(int threadIndex, int startIndex, int length, double[] x) {
            super(threadIndex, startIndex, length, x);
            this.tempSums = new double[ParallelNegLogLikelihood.this.numOutcomes];
        }

        @Override
        public NegLLComputeTask call() {
            double partial = 0.0;
            for (int ci = this.startIndex; ci < this.startIndex + this.length; ++ci) {
                for (int oi = 0; oi < numOutcomes; ++oi) {
                    this.tempSums[oi] = 0.0;
                    for (int ai = 0; ai < contexts[ci].length; ++ai) {
                        int vectorIndex = indexOf(oi, contexts[ci][ai]);
                        double predValue = values != null ? values[ci][ai] : 1.0;
                        this.tempSums[oi] += predValue * this.x[vectorIndex];
                    }
                }
                double logSumOfExps = ArrayMath.logSumOfExps(this.tempSums);
                int outcome = outcomeList[ci];
                // Subtract the (count-weighted) log-probability of the observed outcome.
                partial -= (this.tempSums[outcome] - logSumOfExps) * numTimesEventsSeen[ci];
            }
            ParallelNegLogLikelihood.this.negLogLikelihoodThread[this.threadIndex] = partial;
            return this;
        }
    }

    /**
     * Base class for the work done by one worker: {@code length} contexts starting at
     * {@code startIndex}, evaluated at the parameter vector {@code x}, with the result
     * written to slot {@code threadIndex}.
     */
    abstract class ComputeTask
    implements Callable<ComputeTask> {
        final int threadIndex;
        final int startIndex;
        final int length;
        final double[] x;

        public ComputeTask(int threadIndex, int startIndex, int length, double[] x) {
            this.threadIndex = threadIndex;
            this.startIndex = startIndex;
            this.length = length;
            this.x = x;
        }
    }
}

