/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.semi_supervised;

import cc.mallet.fst.CRF;
import cc.mallet.fst.SumLattice;
import cc.mallet.fst.SumLatticeDefault;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.semi_supervised.GELattice;
import cc.mallet.fst.semi_supervised.GELatticeTask;
import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.fst.semi_supervised.SumLatticeTask;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.Sequence;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * Optimizable objective for training a {@link CRF} against a set of {@link GEConstraint}s.
 *
 * <p>The cached value is the sum of {@code constraint.getValue()} over all constraints plus
 * {@code crf.getParameters().gaussianPrior(gpv)}, multiplied by {@code weight}. The gradient is
 * accumulated from one {@link GELattice} per constrained instance plus the Gaussian-prior
 * gradient, and is multiplied by {@code weight} when copied out.
 *
 * <p>Value and gradient are recomputed lazily whenever the CRF's weights-value change stamp
 * differs from the stamp recorded at the last computation. When more than one thread is
 * requested, lattice and gradient computation are split into contiguous instance ranges and run
 * on a private fixed thread pool; call {@link #shutdown()} when done to release it.
 *
 * <p>Instance state is mutated without synchronization, so a single instance should not be used
 * from several threads at once.
 */
public class CRFOptimizableByGE
implements Optimizable.ByGradientValue {
    /** Default Gaussian prior variance used until {@link #setGaussianPriorVariance} is called. */
    private static final int DEFAULT_GPV = 10;
    private final CRF crf;
    private final ArrayList<GEConstraint> constraints;
    private final InstanceList data;
    /** Requested worker count; the effective count is also capped by the data size. */
    private final int numThreads;
    private double gpv;
    /** Multiplier applied to both the cached value and the returned gradient. */
    private final double weight;
    /** CRF weights-value change stamp at the time of the last cache fill. */
    private int cache;
    private double cachedValue;
    private final CRF.Factors cachedGradient;
    /** reverseTrans[sj][k]: source state index of the k-th transition into state sj. */
    private int[][] reverseTrans;
    /** reverseTransIndices[sj][k]: destination index (within the source state) of that transition. */
    private int[][] reverseTransIndices;
    /** Bit ii is set when at least one constraint's preProcess flagged instance ii. */
    private final BitSet instancesWithConstraints;
    /** Worker pool; null when {@code numThreads <= 1}. */
    private final ThreadPoolExecutor executor;

    /**
     * Creates the objective with weight 1.0.
     *
     * @param crf the CRF whose parameters are optimized
     * @param constraints GE constraints contributing to the value and gradient
     * @param data training instances; their data is cast to {@link FeatureVectorSequence}
     * @param map state-to-label map handed to every constraint
     * @param numThreads number of worker threads; values &lt;= 1 run sequentially
     */
    public CRFOptimizableByGE(CRF crf, ArrayList<GEConstraint> constraints, InstanceList data, StateLabelMap map, int numThreads) {
        this(crf, constraints, data, map, numThreads, 1.0);
    }

    /**
     * Creates the objective.
     *
     * @param crf the CRF whose parameters are optimized
     * @param constraints GE constraints contributing to the value and gradient
     * @param data training instances; their data is cast to {@link FeatureVectorSequence}
     * @param map state-to-label map handed to every constraint
     * @param numThreads number of worker threads; values &lt;= 1 run sequentially
     * @param weight multiplier applied to the value and the gradient
     */
    public CRFOptimizableByGE(CRF crf, ArrayList<GEConstraint> constraints, InstanceList data, StateLabelMap map, int numThreads, double weight) {
        this.crf = crf;
        this.constraints = constraints;
        this.cache = Integer.MAX_VALUE;
        this.cachedValue = Double.NaN;
        this.cachedGradient = new CRF.Factors(crf);
        this.data = data;
        this.numThreads = numThreads;
        this.weight = weight;
        this.instancesWithConstraints = new BitSet(data.size());
        for (GEConstraint constraint : constraints) {
            constraint.setStateLabelMap(map);
            // Union of the instances each constraint reports as relevant; others are skipped later.
            this.instancesWithConstraints.or(constraint.preProcess(data));
        }
        this.gpv = DEFAULT_GPV;
        this.executor = numThreads > 1 ? (ThreadPoolExecutor) Executors.newFixedThreadPool(numThreads) : null;
        this.createReverseTransitionMatrices(crf);
    }

    /**
     * Builds the incoming-transition tables {@code reverseTrans} and {@code reverseTransIndices}
     * from the outgoing transitions of every CRF state.
     *
     * @param crf the CRF whose state graph is inverted
     */
    public void createReverseTransitionMatrices(CRF crf) {
        int numStates = crf.numStates();
        // Pass 1: count incoming transitions per destination state to size the rows.
        int[] inDegree = new int[numStates];
        for (int si = 0; si < numStates; ++si) {
            CRF.State source = (CRF.State) crf.getState(si);
            for (int di = 0; di < source.numDestinations(); ++di) {
                inDegree[source.getDestinationState(di).getIndex()]++;
            }
        }
        this.reverseTrans = new int[numStates][];
        this.reverseTransIndices = new int[numStates][];
        for (int s = 0; s < numStates; ++s) {
            this.reverseTrans[s] = new int[inDegree[s]];
            this.reverseTransIndices[s] = new int[inDegree[s]];
        }
        // Pass 2: fill each row in (source state, destination index) order.
        int[] filled = new int[numStates];
        for (int si = 0; si < numStates; ++si) {
            CRF.State source = (CRF.State) crf.getState(si);
            for (int di = 0; di < source.numDestinations(); ++di) {
                int sj = source.getDestinationState(di).getIndex();
                int slot = filled[sj]++;
                this.reverseTrans[sj][slot] = si;
                this.reverseTransIndices[sj][slot] = di;
            }
        }
    }

    @Override
    public int getNumParameters() {
        return this.crf.getNumParameters();
    }

    @Override
    public void getParameters(double[] buffer) {
        this.crf.getParameters().getParameters(buffer);
    }

    @Override
    public double getParameter(int index) {
        return this.crf.getParameters().getParameter(index);
    }

    /** Sets all CRF parameters and notifies the CRF so the cached value/gradient is invalidated. */
    @Override
    public void setParameters(double[] params) {
        this.crf.getParameters().setParameters(params);
        this.crf.weightsValueChanged();
    }

    /** Sets one CRF parameter and notifies the CRF so the cached value/gradient is invalidated. */
    @Override
    public void setParameter(int index, double value) {
        this.crf.getParameters().setParameter(index, value);
        this.crf.weightsValueChanged();
    }

    /**
     * Recomputes and caches the objective value and (unweighted) gradient for the current CRF
     * parameters.
     *
     * @throws IllegalStateException if interrupted while waiting for worker tasks, or if a worker
     *     task failed with a checked exception
     * @throws RuntimeException any unchecked exception thrown by a worker task is rethrown as-is
     */
    public void cacheValueAndGradient() {
        ArrayList<SumLattice> lattices = this.computeLattices();
        System.err.println("Done computing lattices.");
        for (GEConstraint constraint : this.constraints) {
            constraint.zeroExpectations();
            constraint.computeExpectations(lattices);
        }
        System.err.println("Done computing expectations.");
        this.cachedValue = 0.0;
        for (GEConstraint constraint : this.constraints) {
            this.cachedValue += constraint.getValue();
        }
        this.cachedGradient.zero();
        this.accumulateGradient(lattices);
        System.err.println("Done computing gradient.");
        this.cachedValue += this.crf.getParameters().gaussianPrior(this.gpv);
        this.cachedGradient.plusEqualsGaussianPriorGradient(this.crf.getParameters(), this.gpv);
        System.err.println("Done computing regularization.");
        if (this.weight != 1.0) {
            this.cachedValue *= this.weight;
        }
        System.err.println("GE Value = " + this.cachedValue);
    }

    /**
     * Computes one lattice per instance, with null entries for instances that no constraint
     * covers. The returned list is indexed like {@code data}.
     */
    private ArrayList<SumLattice> computeLattices() {
        int size = this.data.size();
        int threads = this.effectiveThreadCount();
        ArrayList<SumLattice> lattices = new ArrayList<>(size);
        if (threads <= 1) {
            for (int ii = 0; ii < size; ++ii) {
                if (this.instancesWithConstraints.get(ii)) {
                    lattices.add(new SumLatticeDefault(this.crf, this.sequenceAt(ii), null, null, true));
                } else {
                    lattices.add(null);
                }
            }
            return lattices;
        }
        ArrayList<SumLatticeTask> tasks = new ArrayList<>(threads);
        for (int t = 0; t < threads; ++t) {
            tasks.add(new SumLatticeTask(this.crf, this.data, this.instancesWithConstraints,
                    chunkStart(t, threads, size), chunkEnd(t, threads, size)));
        }
        this.runAll(tasks);
        // Chunks are contiguous and in order, so concatenation restores data indexing.
        for (SumLatticeTask task : tasks) {
            lattices.addAll(task.getLattices());
        }
        assert (lattices.size() == size) : lattices.size() + " " + size;
        return lattices;
    }

    /** Adds the GE gradient contribution of every constrained instance into {@code cachedGradient}. */
    private void accumulateGradient(ArrayList<SumLattice> lattices) {
        int size = this.data.size();
        int threads = this.effectiveThreadCount();
        if (threads <= 1) {
            for (int ii = 0; ii < size; ++ii) {
                if (!this.instancesWithConstraints.get(ii)) {
                    continue;
                }
                SumLattice lattice = lattices.get(ii);
                // The GELattice object is discarded; presumably its constructor adds this instance's
                // contribution into cachedGradient (passed in) — confirm in GELattice.
                new GELattice(this.sequenceAt(ii), lattice.getGammas(), lattice.getXis(), this.crf,
                        this.reverseTrans, this.reverseTransIndices, this.cachedGradient, this.constraints, false);
            }
            return;
        }
        ArrayList<GELatticeTask> tasks = new ArrayList<>(threads);
        for (int t = 0; t < threads; ++t) {
            // Each task gets its own constraint copies; presumably constraints carry state that
            // must not be shared across threads — TODO confirm against GEConstraint.copy().
            ArrayList<GEConstraint> taskConstraints = new ArrayList<>(this.constraints.size());
            for (GEConstraint constraint : this.constraints) {
                taskConstraints.add(constraint.copy());
            }
            tasks.add(new GELatticeTask(this.crf, this.data, lattices, taskConstraints, this.instancesWithConstraints,
                    this.reverseTrans, this.reverseTransIndices, chunkStart(t, threads, size), chunkEnd(t, threads, size)));
        }
        this.runAll(tasks);
        for (GELatticeTask task : tasks) {
            this.cachedGradient.plusEquals(task.getGradient(), 1.0);
        }
    }

    /** Casts the data of instance {@code ii} to the sequence type the lattices consume. */
    private FeatureVectorSequence sequenceAt(int ii) {
        return (FeatureVectorSequence) ((Instance) this.data.get(ii)).getData();
    }

    /**
     * Number of chunks to split the data into: at most one per instance, and 1 (sequential) when
     * no worker pool exists. May be 0 for empty data, which callers treat as sequential.
     */
    private int effectiveThreadCount() {
        if (this.executor == null) {
            return 1;
        }
        return Math.min(this.numThreads, this.data.size());
    }

    /** Inclusive start index of chunk {@code chunk} out of {@code chunks} over {@code size} items. */
    private static int chunkStart(int chunk, int chunks, int size) {
        return chunk * (size / chunks);
    }

    /** Exclusive end index of a chunk; the last chunk absorbs the division remainder. */
    private static int chunkEnd(int chunk, int chunks, int size) {
        return chunk == chunks - 1 ? size : (chunk + 1) * (size / chunks);
    }

    /**
     * Runs all tasks on the pool and waits for them, surfacing any failure instead of silently
     * continuing with partial results.
     */
    private <T> void runAll(Collection<? extends Callable<T>> tasks) {
        List<Future<T>> futures;
        try {
            futures = this.executor.invokeAll(tasks);
        } catch (InterruptedException ie) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while waiting for GE worker tasks", ie);
        }
        for (Future<T> future : futures) {
            try {
                future.get();
            } catch (InterruptedException ie) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while waiting for GE worker tasks", ie);
            } catch (ExecutionException ee) {
                Throwable cause = ee.getCause();
                if (cause instanceof RuntimeException) {
                    throw (RuntimeException) cause;
                }
                if (cause instanceof Error) {
                    throw (Error) cause;
                }
                throw new IllegalStateException("GE worker task failed", cause);
            }
        }
    }

    /** Sets the variance passed to the Gaussian prior term on the next recomputation. */
    public void setGaussianPriorVariance(double variance) {
        this.gpv = variance;
    }

    /** Recomputes the cache if the CRF weights changed since the last computation. */
    private void ensureCached() {
        if (this.crf.getWeightsValueChangeStamp() != this.cache) {
            this.cacheValueAndGradient();
            this.cache = this.crf.getWeightsValueChangeStamp();
        }
    }

    /** Copies the gradient, multiplied by {@code weight}, into {@code buffer}. */
    @Override
    public void getValueGradient(double[] buffer) {
        this.ensureCached();
        this.cachedGradient.getParameters(buffer);
        if (this.weight != 1.0) {
            MatrixOps.timesEquals(buffer, this.weight);
        }
    }

    /** Returns the objective value; the weight has already been applied when it was cached. */
    @Override
    public double getValue() {
        this.ensureCached();
        return this.cachedValue;
    }

    /**
     * Shuts down the worker pool, if any. Waits up to 30 seconds for running tasks, then
     * force-stops the pool regardless of whether assertions are enabled.
     */
    public void shutdown() {
        if (this.executor == null) {
            return;
        }
        this.executor.shutdown();
        try {
            this.executor.awaitTermination(30L, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        // Called outside the assert so the forced shutdown also happens when assertions are off.
        List<Runnable> neverStarted = this.executor.shutdownNow();
        assert (neverStarted.isEmpty()) : "All tasks didn't finish";
    }
}

