/*
 * Decompiled with CFR 0.152.
 */
package com.googlecode.clearnlp.component.pos;

import com.googlecode.clearnlp.classification.algorithm.AbstractAlgorithm;
import com.googlecode.clearnlp.classification.model.StringModel;
import com.googlecode.clearnlp.classification.prediction.StringPrediction;
import com.googlecode.clearnlp.classification.train.StringTrainSpace;
import com.googlecode.clearnlp.classification.vector.StringFeatureVector;
import com.googlecode.clearnlp.component.AbstractStatisticalComponent;
import com.googlecode.clearnlp.dependency.DEPNode;
import com.googlecode.clearnlp.dependency.DEPTree;
import com.googlecode.clearnlp.engine.EngineProcess;
import com.googlecode.clearnlp.feature.xml.FtrToken;
import com.googlecode.clearnlp.feature.xml.JointFtrXml;
import com.googlecode.clearnlp.pos.POSState;
import com.googlecode.clearnlp.util.UTInput;
import com.googlecode.clearnlp.util.UTOutput;
import com.googlecode.clearnlp.util.UTString;
import com.googlecode.clearnlp.util.map.Prob2DMap;
import com.googlecode.clearnlp.util.pair.ObjectDoublePair;
import com.googlecode.clearnlp.util.pair.Pair;
import com.googlecode.clearnlp.util.pair.StringDoublePair;
import com.googlecode.clearnlp.util.triple.Triple;
import java.io.BufferedReader;
import java.io.PrintStream;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;

/**
 * Statistical part-of-speech tagger that, in addition to a single greedy left-to-right pass,
 * can re-tag the remainder of a sentence from every position where the two best predictions
 * were closer than a configurable margin, and then keep the best-scoring sequence.
 *
 * <p>The behaviour of {@link #process(DEPTree)} depends on the inherited {@code i_flag}; the
 * meanings below are read off the code in this class (the symbolic names are declared
 * elsewhere):
 * <ul>
 *   <li>{@code 0} - only collect (form, gold tag) counts for the ambiguity-class lexicon;</li>
 *   <li>{@code 1} - tag with gold labels and collect training instances;</li>
 *   <li>{@code 2} - tag with the model; gold tags are neither saved nor cleared;</li>
 *   <li>{@code 3} - tag with the model, collect instances labelled with the gold tags;</li>
 *   <li>{@code 4} - tag with the model after saving and clearing the gold tags.</li>
 * </ul>
 *
 * <p>NOTE(review): the superclass constructors presumably call {@link #initLexia(Object[])} and
 * {@link #loadModels(ZipInputStream)} before this class's field initializers run - confirm.
 * For that reason the lexicon fields deliberately have no initializers, and the constants
 * below are referenced by simple name so that they are compiled in as constants.
 */
public class CPOSBackTagger
extends AbstractStatisticalComponent {
    /** Zip entry holding the default configuration. */
    protected final String ENTRY_CONFIGURATION = "pos_CONFIGURATION";
    /** Prefix of the zip entries holding feature templates; an integer index follows it. */
    protected final String ENTRY_FEATURE = "pos_FEATURE";
    /** Zip entry holding the lexica written by {@link #saveLexica(ZipOutputStream)}. */
    protected final String ENTRY_LEXICA = "pos_LEXICA";
    /** Prefix of the zip entries holding statistical models; an integer index follows it. */
    protected final String ENTRY_MODEL = "pos_MODEL";
    /** Index of the lower-simplified-form set within the lexica array. */
    protected final int LEXICA_LOWER_SIMPLIFIED_FORMS = 0;
    /** Index of the ambiguity-class map within the lexica array. */
    protected final int LEXICA_AMBIGUITY_CLASSES = 1;

    /** Lower simplified forms for which form features and ambiguity counts are produced. */
    protected Set<String> s_lsfs;
    /** (simplified form, gold tag) counts gathered while {@code i_flag == 0}. */
    protected Prob2DMap p_ambi;
    /** Simplified form -> ambiguity class (tags joined with {@code "_"}). */
    protected Map<String, String> m_ambi;
    /** Gold tags of the current tree, indexed by node id. */
    protected String[] g_tags;
    /** Id of the node that will be tagged next. */
    protected int i_input;
    /** Sum of the prediction scores of the tags assigned so far in the current pass. */
    protected double d_score;
    /** A branch is recorded when best and second-best scores differ by less than this. */
    protected double d_margin;

    /** Creates an empty tagger; no lexica or models are set. */
    public CPOSBackTagger() {
    }

    /**
     * Creates a tagger for collecting ambiguity-class counts.
     *
     * @param xmls feature templates, passed to the superclass
     * @param sLsfs lower simplified forms to collect counts for
     */
    public CPOSBackTagger(JointFtrXml[] xmls, Set<String> sLsfs) {
        super(xmls);
        this.s_lsfs = sLsfs;
        this.p_ambi = new Prob2DMap();
    }

    /**
     * @param xmls feature templates
     * @param spaces training spaces that receive the collected instances
     * @param lexica lexica in the layout produced by {@link #getLexica()}
     * @param margin score margin below which an alternative branch is recorded
     */
    public CPOSBackTagger(JointFtrXml[] xmls, StringTrainSpace[] spaces, Object[] lexica, double margin) {
        super(xmls, spaces, lexica);
        this.d_margin = margin;
    }

    /**
     * @param xmls feature templates
     * @param models statistical models used for prediction
     * @param lexica lexica in the layout produced by {@link #getLexica()}
     * @param margin score margin below which an alternative branch is recorded
     */
    public CPOSBackTagger(JointFtrXml[] xmls, StringModel[] models, Object[] lexica, double margin) {
        super(xmls, models, lexica);
        this.d_margin = margin;
    }

    /**
     * @param xmls feature templates
     * @param spaces training spaces that receive the collected instances
     * @param models statistical models used for prediction
     * @param lexica lexica in the layout produced by {@link #getLexica()}
     * @param margin score margin below which an alternative branch is recorded
     */
    public CPOSBackTagger(JointFtrXml[] xmls, StringTrainSpace[] spaces, StringModel[] models, Object[] lexica, double margin) {
        super(xmls, spaces, models, lexica);
        this.d_margin = margin;
    }

    /**
     * Creates a tagger from a zipped model; the margin keeps its default of {@code 0.0}.
     *
     * @param in stream handed to the superclass constructor
     */
    public CPOSBackTagger(ZipInputStream in) {
        super(in);
    }

    /**
     * Installs the lexica from an array laid out as in {@link #getLexica()}.
     *
     * @param lexica element {@code 0} is the form set, element {@code 1} the ambiguity map
     */
    @Override
    @SuppressWarnings("unchecked") // the array is untyped; the layout is fixed by getLexica()
    protected void initLexia(Object[] lexica) {
        this.s_lsfs = (Set<String>) lexica[LEXICA_LOWER_SIMPLIFIED_FORMS];
        this.m_ambi = (Map<String, String>) lexica[LEXICA_AMBIGUITY_CLASSES];
    }

    /**
     * Reads configuration, feature templates, models and lexica from the zip stream,
     * dispatching on the entry names. Unknown entries are skipped. Any exception aborts the
     * loading; it is printed and not propagated.
     *
     * @param zin zip stream to read entries from
     */
    @Override
    public void loadModels(ZipInputStream zin) {
        int featurePrefixLength = ENTRY_FEATURE.length();
        int modelPrefixLength = ENTRY_MODEL.length();
        this.f_xmls = new JointFtrXml[1];
        this.s_models = null;
        try {
            ZipEntry zEntry;
            while ((zEntry = zin.getNextEntry()) != null) {
                String entry = zEntry.getName();
                if (entry.equals(ENTRY_CONFIGURATION)) {
                    this.loadDefaultConfiguration(zin);
                } else if (entry.startsWith(ENTRY_FEATURE)) {
                    // The text after the prefix is the index of the template.
                    this.loadFeatureTemplates(zin, Integer.parseInt(entry.substring(featurePrefixLength)));
                } else if (entry.startsWith(ENTRY_MODEL)) {
                    // The text after the prefix is the index of the model.
                    this.loadStatisticalModels(zin, Integer.parseInt(entry.substring(modelPrefixLength)));
                } else if (entry.equals(ENTRY_LEXICA)) {
                    this.loadLexica(zin);
                }
            }
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Reads the form set followed by the ambiguity map from the current zip entry.
     *
     * @param zin zip stream positioned at the lexica entry
     * @throws Exception if reading fails
     */
    protected void loadLexica(ZipInputStream zin) throws Exception {
        // The reader is intentionally left open: closing it would close the zip stream,
        // which loadModels() still iterates over.
        BufferedReader fin = UTInput.createBufferedReader(zin);
        System.out.println("Loading lexica.");
        this.s_lsfs = UTInput.getStringSet(fin);
        this.m_ambi = UTInput.getStringMap(fin, " ");
    }

    /**
     * Writes configuration, feature templates, lexica and models to the zip stream and closes
     * it. Exceptions are printed and not propagated; the stream is closed in either case.
     *
     * @param zout zip stream to write to; closed by this method
     */
    @Override
    public void saveModels(ZipOutputStream zout) {
        try {
            this.saveDefaultConfiguration(zout, ENTRY_CONFIGURATION);
            this.saveFeatureTemplates(zout, ENTRY_FEATURE);
            this.saveLexica(zout);
            this.saveStatisticalModels(zout, ENTRY_MODEL);
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        finally {
            // Close on failure as well, so a failed save does not leak the stream.
            try {
                zout.close();
            }
            catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Writes the form set followed by the ambiguity map as one zip entry.
     *
     * @param zout zip stream to add the lexica entry to
     * @throws Exception if writing fails
     */
    protected void saveLexica(ZipOutputStream zout) throws Exception {
        zout.putNextEntry(new ZipEntry(ENTRY_LEXICA));
        PrintStream fout = UTOutput.createPrintBufferedStream(zout);
        System.out.println("Saving lexica.");
        UTOutput.printSet(fout, this.s_lsfs);
        fout.flush();
        UTOutput.printMap(fout, this.m_ambi, " ");
        // Flush the buffered stream before the entry is closed underneath it.
        fout.flush();
        zout.closeEntry();
    }

    /**
     * Returns the lexica as a two-element array: the form set and the ambiguity map. While
     * {@code i_flag == 0} the map is built from the counts collected so far; otherwise the
     * installed map is returned.
     *
     * @return array indexed by the {@code LEXICA_*} constants
     */
    @Override
    public Object[] getLexica() {
        Object[] lexica = new Object[2];
        lexica[LEXICA_LOWER_SIMPLIFIED_FORMS] = this.s_lsfs;
        lexica[LEXICA_AMBIGUITY_CLASSES] = this.i_flag == 0 ? this.getAmbiguityClasses() : this.m_ambi;
        return lexica;
    }

    /** @return the set of lower simplified forms (the live set, not a copy) */
    public Set<String> getLowerSimplifiedForms() {
        return this.s_lsfs;
    }

    /** Empties the set of lower simplified forms; does nothing if no set is installed. */
    public void clearLowerSimplifiedForms() {
        if (this.s_lsfs != null) {
            this.s_lsfs.clear();
        }
    }

    /**
     * Builds the ambiguity-class map from the collected counts. For every form, the tags
     * whose probability exceeds the threshold of the first feature template are joined with
     * {@code "_"}; forms without such a tag are omitted.
     *
     * @return map from simplified form to ambiguity class
     */
    private Map<String, String> getAmbiguityClasses() {
        double threshold = this.f_xmls[0].getAmbiguityClassThreshold();
        HashMap<String, String> mAmbi = new HashMap<String, String>();
        for (String key : this.p_ambi.keySet()) {
            StringBuilder build = new StringBuilder();
            Object[] ps = this.p_ambi.getProb1D(key);
            // Relies on the natural ordering of StringDoublePair; the early break below
            // presumes that ordering is by descending probability - TODO confirm.
            Arrays.sort(ps);
            for (Object o : ps) {
                StringDoublePair p = (StringDoublePair) o;
                if (p.d <= threshold) {
                    break;
                }
                build.append("_");
                build.append(p.s);
            }
            if (build.length() > 0) {
                // Drop the leading separator.
                mAmbi.put(key, build.substring(1));
            }
        }
        return mAmbi;
    }

    /**
     * Adds the results for the current tree to the running totals.
     *
     * @param counts {@code counts[0]} grows by the number of nodes after the first,
     *        {@code counts[1]} by the number of those whose tag equals the gold tag
     */
    @Override
    public void countAccuracy(int[] counts) {
        int correct = 0;
        // Node 0 is skipped throughout this class (presumably an artificial root).
        for (int i = 1; i < this.t_size; ++i) {
            if (this.d_tree.get(i).pos.equals(this.g_tags[i])) {
                ++correct;
            }
        }
        counts[0] = counts[0] + (this.t_size - 1);
        counts[1] = counts[1] + correct;
    }

    /**
     * Processes one tree according to {@code i_flag}.
     *
     * @param tree tree to process; its tags are modified in place
     */
    @Override
    public void process(DEPTree tree) {
        this.init(tree);
        this.processAux();
    }

    /**
     * Resets the per-tree state. Unless {@code i_flag == 2}, the tree's current tags are
     * saved as gold tags and cleared from the tree.
     *
     * @param tree tree about to be processed
     */
    protected void init(DEPTree tree) {
        this.d_tree = tree;
        this.t_size = tree.size();
        this.d_score = 0.0;
        this.i_input = 1;
        if (this.i_flag != 2) {
            this.g_tags = tree.getPOSTags();
            tree.clearPOSTags();
        }
        EngineProcess.normalizeForms(tree);
    }

    /**
     * Collects lexicon counts when {@code i_flag == 0}; otherwise tags the tree and adds any
     * returned instances to the first training space.
     */
    protected void processAux() {
        if (this.i_flag == 0) {
            this.addLexica();
            return;
        }
        List<Pair<String, StringFeatureVector>> insts = this.tag();
        if (insts != null) {
            for (Pair<String, StringFeatureVector> inst : insts) {
                this.s_spaces[0].addInstance((String) inst.o1, (StringFeatureVector) inst.o2);
            }
        }
    }

    /** Counts (simplified form, gold tag) for every node whose lower form is in the set. */
    protected void addLexica() {
        for (int i = 1; i < this.t_size; ++i) {
            DEPNode node = this.d_tree.get(i);
            if (this.s_lsfs.contains(node.lowerSimplifiedForm)) {
                this.p_ambi.add(node.simplifiedForm, this.g_tags[i]);
            }
        }
    }

    /**
     * Tags the current tree: a single pass when {@code i_flag == 1}, otherwise the main pass
     * plus all recorded branches.
     *
     * @return collected training instances, or {@code null} when none are collected
     */
    @SuppressWarnings("unchecked") // parameterized cast of the triple's second element
    protected List<Pair<String, StringFeatureVector>> tag() {
        if (this.i_flag == 1) {
            return (List<Pair<String, StringFeatureVector>>) this.tagMain().o2;
        }
        return this.tagBranches();
    }

    /**
     * Tags every node from {@code i_input} to the end of the tree.
     *
     * @return the tree's tags after the pass, the instances collected during the pass, and
     *         the branch points recorded during the pass
     */
    protected Triple<String[], List<Pair<String, StringFeatureVector>>, Deque<POSState>> tagMain() {
        ArrayList<Pair<String, StringFeatureVector>> insts = new ArrayList<Pair<String, StringFeatureVector>>();
        ArrayDeque<POSState> states = new ArrayDeque<POSState>();
        while (this.i_input < this.t_size) {
            this.tagAux(this.getLabel(insts, states));
        }
        return new Triple<String[], List<Pair<String, StringFeatureVector>>, Deque<POSState>>(this.d_tree.getPOSTags(), insts, states);
    }

    /** Assigns the label to the current node, adds its score and advances to the next node. */
    private void tagAux(StringPrediction p) {
        this.d_tree.get(this.i_input).pos = p.label;
        this.d_score += p.score;
        ++this.i_input;
    }

    /**
     * Runs the main pass, then one additional pass per branch point recorded by the main
     * pass. When {@code i_flag} is 2 or 4 the sequence with the highest accumulated score is
     * written to the tree and {@code null} is returned. Otherwise the sequences are re-scored
     * by the number of gold-matching tags and the instances of the main pass followed by
     * those of the best sequence are returned.
     *
     * @return training instances, or {@code null} when {@code i_flag} is 2 or 4
     */
    @SuppressWarnings("unchecked") // parameterized casts of pair/triple elements
    protected List<Pair<String, StringFeatureVector>> tagBranches() {
        ArrayList<ObjectDoublePair<Triple<String[], List<Pair<String, StringFeatureVector>>, Deque<POSState>>>> list = new ArrayList<ObjectDoublePair<Triple<String[], List<Pair<String, StringFeatureVector>>, Deque<POSState>>>>();
        Triple<String[], List<Pair<String, StringFeatureVector>>, Deque<POSState>> t0 = this.tagMain();
        list.add(new ObjectDoublePair<Triple<String[], List<Pair<String, StringFeatureVector>>, Deque<POSState>>>(t0, this.d_score));
        // Only branch points of the main pass are expanded; those recorded while tagging a
        // branch end up in that branch's own triple and are not followed.
        for (POSState state : (Deque<POSState>) t0.o3) {
            this.reset(state);
            list.add(new ObjectDoublePair<Triple<String[], List<Pair<String, StringFeatureVector>>, Deque<POSState>>>(this.tagMain(), this.d_score));
        }
        if (this.i_flag == 2 || this.i_flag == 4) {
            Triple<String[], List<Pair<String, StringFeatureVector>>, Deque<POSState>> best = this.getMax(list);
            this.d_tree.resetPOSTags((String[]) best.o1);
            return null;
        }
        this.setGoldScores(list);
        Triple<String[], List<Pair<String, StringFeatureVector>>, Deque<POSState>> best = this.getMax(list);
        // NOTE(review): when the main pass is itself the best sequence its instances are
        // added twice; kept as is because it may be an intended weighting - confirm.
        ArrayList<Pair<String, StringFeatureVector>> insts = new ArrayList<Pair<String, StringFeatureVector>>((List<Pair<String, StringFeatureVector>>) t0.o2);
        insts.addAll((List<Pair<String, StringFeatureVector>>) best.o2);
        return insts;
    }

    /**
     * Rewinds to a branch point: restores position and score, assigns the alternative label
     * there and clears the tags of all following nodes.
     */
    private void reset(POSState state) {
        this.i_input = state.input;
        this.d_score = state.score;
        this.tagAux(state.label);
        // tagAux() already advanced i_input to the first node that is still to be tagged,
        // so clearing starts there; starting one later would leave a stale tag on that node.
        for (int i = this.i_input; i < this.t_size; ++i) {
            this.d_tree.get(i).pos = null;
        }
    }

    /** Returns the sequence with the highest score; the earliest one wins ties. */
    @SuppressWarnings("unchecked") // parameterized cast of the pair's object
    private Triple<String[], List<Pair<String, StringFeatureVector>>, Deque<POSState>> getMax(List<ObjectDoublePair<Triple<String[], List<Pair<String, StringFeatureVector>>, Deque<POSState>>>> list) {
        ObjectDoublePair<Triple<String[], List<Pair<String, StringFeatureVector>>, Deque<POSState>>> max = list.get(0);
        int size = list.size();
        for (int i = 1; i < size; ++i) {
            ObjectDoublePair<Triple<String[], List<Pair<String, StringFeatureVector>>, Deque<POSState>>> p = list.get(i);
            if (max.d < p.d) {
                max = p;
            }
        }
        return (Triple<String[], List<Pair<String, StringFeatureVector>>, Deque<POSState>>) max.o;
    }

    /** Replaces each sequence's score with the number of its tags that equal the gold tag. */
    @SuppressWarnings("unchecked") // parameterized cast of the pair's object
    private void setGoldScores(List<ObjectDoublePair<Triple<String[], List<Pair<String, StringFeatureVector>>, Deque<POSState>>>> list) {
        for (ObjectDoublePair<Triple<String[], List<Pair<String, StringFeatureVector>>, Deque<POSState>>> p : list) {
            String[] tags = (String[]) ((Triple<String[], List<Pair<String, StringFeatureVector>>, Deque<POSState>>) p.o).o1;
            int matches = 0;
            for (int i = 1; i < this.t_size; ++i) {
                if (this.g_tags[i].equals(tags[i])) {
                    ++matches;
                }
            }
            p.d = matches;
        }
    }

    /**
     * Produces the label for the current node and, depending on {@code i_flag}, records a
     * training instance (only when the feature vector is non-empty).
     *
     * @param insts receives the training instance, if any
     * @param states receives a branch point, if the model prediction is ambiguous
     * @return the label to assign; {@code null} if {@code i_flag} is not one of 1 to 4
     */
    private StringPrediction getLabel(List<Pair<String, StringFeatureVector>> insts, Deque<POSState> states) {
        StringFeatureVector vector = this.getFeatureVector(this.f_xmls[0]);
        StringPrediction p = null;
        if (this.i_flag == 1) {
            p = this.getGoldLabel();
            if (vector.size() > 0) {
                insts.add(new Pair<String, StringFeatureVector>(p.label, vector));
            }
        } else if (this.i_flag == 2 || this.i_flag == 4) {
            p = this.getAutoLabel(vector, states);
        } else if (this.i_flag == 3) {
            // The predicted label is assigned to the tree, but the instance carries the
            // gold label.
            p = this.getAutoLabel(vector, states);
            if (vector.size() > 0) {
                insts.add(new Pair<String, StringFeatureVector>(this.getGoldLabel().label, vector));
            }
        }
        return p;
    }

    /** Returns the gold tag of the current node with a score of {@code 1.0}. */
    private StringPrediction getGoldLabel() {
        return new StringPrediction(this.g_tags[this.i_input], 1.0);
    }

    /**
     * Predicts the label of the current node with the first model. If a second-best
     * prediction exists and trails the best one by less than the margin, a branch point
     * carrying the second-best label is recorded.
     *
     * @param vector features of the current node
     * @param states receives the branch point, if any
     * @return the best prediction
     */
    private StringPrediction getAutoLabel(StringFeatureVector vector, Deque<POSState> states) {
        List<StringPrediction> ps = this.s_models[0].predictAll(vector);
        AbstractAlgorithm.normalize(ps);
        StringPrediction fst = ps.get(0);
        // With a single prediction there is no alternative to branch on.
        if (ps.size() > 1) {
            StringPrediction snd = ps.get(1);
            if (fst.score - snd.score < this.d_margin) {
                states.add(new POSState(this.i_input, this.d_score, snd));
            }
        }
        return fst;
    }

    /**
     * Extracts a single-valued feature for the node addressed by the token.
     *
     * @param token feature token naming the field and the offset from the current node
     * @return the feature value, or {@code null} if the node is out of range, the field is
     *         unknown or the feature does not apply
     * @throws IllegalArgumentException for a boolean feature index outside 0 to 9
     */
    @Override
    protected String getField(FtrToken token) {
        DEPNode node = this.getNodeInput(token);
        if (node == null) {
            return null;
        }
        if (token.isField("sf")) {
            return this.s_lsfs.contains(node.lowerSimplifiedForm) ? node.simplifiedForm : null;
        }
        if (token.isField("lsf")) {
            return this.s_lsfs.contains(node.lowerSimplifiedForm) ? node.lowerSimplifiedForm : null;
        }
        if (token.isField("p")) {
            return node.pos;
        }
        if (token.isField("a")) {
            return this.m_ambi.get(node.simplifiedForm);
        }
        Matcher m = JointFtrXml.P_BOOLEAN.matcher(token.field);
        if (m.find()) {
            // Boolean features yield the field name itself when they hold.
            int field = Integer.parseInt(m.group(1));
            switch (field) {
                case 0: {
                    return UTString.isAllUpperCase(node.simplifiedForm) ? token.field : null;
                }
                case 1: {
                    return UTString.isAllLowerCase(node.simplifiedForm) ? token.field : null;
                }
                case 2: {
                    return UTString.beginsWithUpperCase(node.simplifiedForm) ? token.field : null;
                }
                case 3: {
                    return UTString.getNumOfCapitalsNotAtBeginning(node.simplifiedForm) == 1 ? token.field : null;
                }
                case 4: {
                    return UTString.getNumOfCapitalsNotAtBeginning(node.simplifiedForm) > 1 ? token.field : null;
                }
                case 5: {
                    return node.simplifiedForm.contains(".") ? token.field : null;
                }
                case 6: {
                    return UTString.containsDigit(node.simplifiedForm) ? token.field : null;
                }
                case 7: {
                    return node.simplifiedForm.contains("-") ? token.field : null;
                }
                case 8: {
                    // Current node (not the addressed one) is the last node of the tree.
                    return this.i_input == this.t_size - 1 ? token.field : null;
                }
                case 9: {
                    // Current node (not the addressed one) is the first node after node 0.
                    return this.i_input == 1 ? token.field : null;
                }
            }
            throw new IllegalArgumentException("Unsupported feature: " + field);
        }
        m = JointFtrXml.P_FEAT.matcher(token.field);
        if (m.find()) {
            return node.getFeat(m.group(1));
        }
        m = JointFtrXml.P_PREFIX.matcher(token.field);
        if (m.find()) {
            int n = Integer.parseInt(m.group(1));
            String form = node.lowerSimplifiedForm;
            return n <= form.length() ? form.substring(0, n) : null;
        }
        m = JointFtrXml.P_SUFFIX.matcher(token.field);
        if (m.find()) {
            int n = Integer.parseInt(m.group(1));
            String form = node.lowerSimplifiedForm;
            int len = form.length();
            return n <= len ? form.substring(len - n, len) : null;
        }
        return null;
    }

    /**
     * Extracts a multi-valued feature (prefixes or suffixes of the lower simplified form)
     * for the node addressed by the token.
     *
     * @param token feature token naming the field and the offset from the current node
     * @return the feature values, or {@code null} if the node is out of range, the field is
     *         neither a prefix nor a suffix field, or no value was produced
     */
    @Override
    protected String[] getFields(FtrToken token) {
        DEPNode node = this.getNodeInput(token);
        if (node == null) {
            return null;
        }
        Matcher m = JointFtrXml.P_PREFIX.matcher(token.field);
        if (m.find()) {
            String[] fields = UTString.getPrefixes(node.lowerSimplifiedForm, Integer.parseInt(m.group(1)));
            return fields.length == 0 ? null : fields;
        }
        m = JointFtrXml.P_SUFFIX.matcher(token.field);
        if (m.find()) {
            String[] fields = UTString.getSuffixes(node.lowerSimplifiedForm, Integer.parseInt(m.group(1)));
            return fields.length == 0 ? null : fields;
        }
        return null;
    }

    /**
     * Resolves the node at the token's offset from the current node.
     *
     * @param token feature token carrying the offset
     * @return the node, or {@code null} if the index is node 0 or outside the tree
     */
    protected DEPNode getNodeInput(FtrToken token) {
        int index = this.i_input + token.offset;
        if (index <= 0 || index >= this.t_size) {
            return null;
        }
        return this.d_tree.get(index);
    }
}

