/*
 * Decompiled with CFR 0.152.
 */
package com.googlecode.clearnlp.classification.model;

import com.carrotsearch.hppc.DoubleArrayList;
import com.carrotsearch.hppc.ObjectIntOpenHashMap;
import com.carrotsearch.hppc.cursors.ObjectCursor;
import com.googlecode.clearnlp.classification.model.StringModel;
import com.googlecode.clearnlp.classification.prediction.IntPrediction;
import com.googlecode.clearnlp.classification.prediction.StringPrediction;
import com.googlecode.clearnlp.classification.vector.SparseFeatureVector;
import com.googlecode.clearnlp.classification.vector.StringFeatureVector;
import com.googlecode.clearnlp.util.UTArray;
import com.googlecode.clearnlp.util.pair.Pair;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.regex.Pattern;

/**
 * Online multi-class linear classifier over string features whose weights are updated with
 * per-coordinate AdaGrad step sizes.
 *
 * <p>Weights are stored feature-major: {@code d_weights.get(f).get(l)} is the weight of feature
 * index {@code f} for label index {@code l}. Row 0 is created before any feature and is used by
 * {@link #getScores} as the starting score of every label (a bias row); {@link #addFeature}
 * hands out feature indices starting from {@code n_features}, which is 1 for a fresh model.
 *
 * <p>Labels are stored 1-based in {@code m_labels} and 0-based in {@code a_labels} and in score
 * arrays. NOTE(review): this relies on {@code getLabelIndex} in {@link StringModel} converting
 * between the two (presumably by subtracting 1) — confirm.
 */
public class ONStringModel extends StringModel {
    /** Separator between a feature value and its index in the feature section of a model file. */
    private static final Pattern P_SPACE = Pattern.compile(" ");

    /** Weight matrix: one row per feature index, one column per label index. */
    protected List<DoubleArrayList> d_weights;
    /** Label strings in label-index order; {@code a_labels.get(i)} names {@code scores[i]}. */
    protected List<String> a_labels;
    /** AdaGrad accumulators: per feature and label, the running sum of squared feature values. */
    protected List<DoubleArrayList> d_gs;
    /** AdaGrad learning rate (numerator of every step size). */
    protected double d_alpha;
    /** AdaGrad smoothing term added to the square root of the accumulated sum. */
    protected double d_rho;

    /**
     * Creates an empty model (no labels, only the bias row) ready for online training.
     *
     * @param alpha AdaGrad learning rate
     * @param rho   AdaGrad smoothing term
     */
    public ONStringModel(double alpha, double rho) {
        this.initModel();
        this.initAdaGrad(alpha, rho);
    }

    /**
     * Loads a model previously written by {@link #save(PrintStream)} and prepares it for further
     * training. The AdaGrad accumulators start at zero because they are not part of the saved
     * model.
     *
     * @param reader source of the serialized model
     * @param alpha  AdaGrad learning rate
     * @param rho    AdaGrad smoothing term
     */
    public ONStringModel(BufferedReader reader, double alpha, double rho) {
        this.load(reader);
        this.initAdaGrad(alpha, rho);
    }

    /** Resets the model to contain no labels, no features, and an empty bias row at index 0. */
    private void initModel() {
        this.n_labels = 0;
        this.n_features = 1; // index 0 is reserved for the bias row added below
        this.d_weights = new ArrayList<DoubleArrayList>();
        this.d_gs = new ArrayList<DoubleArrayList>();
        this.a_labels = new ArrayList<String>();
        this.m_labels = new ObjectIntOpenHashMap<String>();
        this.m_features = new HashMap<String, ObjectIntOpenHashMap<String>>();
        // NOTE(review): solver id 3 presumably identifies this online solver to StringModel — confirm.
        this.i_solver = (byte) 3;
        this.d_weights.add(this.getBlankDoubleArrayList(this.n_labels));
    }

    /**
     * Stores the AdaGrad hyper-parameters and allocates zeroed accumulators with the same
     * dimensions as the current weight matrix.
     */
    private void initAdaGrad(double alpha, double rho) {
        this.d_gs = new ArrayList<DoubleArrayList>(this.n_features);
        for (int i = 0; i < this.n_features; ++i) {
            this.d_gs.add(this.getBlankDoubleArrayList(this.n_labels));
        }
        this.d_alpha = alpha;
        this.d_rho = rho;
    }

    /**
     * Returns a new list holding {@code size} zeros.
     *
     * @param size number of entries
     */
    protected DoubleArrayList getBlankDoubleArrayList(int size) {
        DoubleArrayList list = new DoubleArrayList(size);
        for (int i = 0; i < size; ++i) {
            list.add(0.0);
        }
        return list;
    }

    /**
     * Reads the solver id, labels, features, and weights in the order written by
     * {@link #save(PrintStream)}. Any exception is printed to stderr and otherwise ignored, so a
     * failed load can leave the model partially populated.
     */
    @Override
    public void load(BufferedReader reader) {
        System.out.println("Loading model:");
        try {
            this.i_solver = Byte.parseByte(reader.readLine());
            this.loadLabels(reader);
            this.loadFeatures(reader);
            this.loadWeightVector(reader);
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        System.out.println();
    }

    /**
     * Reads the label count followed by one line of space-separated labels.
     *
     * <p>Exactly the first {@code n_labels} tokens become labels. This matters when the count is
     * 0: the labels line is then empty, and {@code "".split(" ")} yields one empty token that
     * must not end up in {@code a_labels}.
     */
    @Override
    protected void loadLabels(BufferedReader fin) throws IOException {
        this.n_labels = Integer.parseInt(fin.readLine());
        String[] labels = fin.readLine().split(" ");
        // A fresh, mutable list so that addLabel can append to it later.
        this.a_labels = new ArrayList<String>(this.n_labels);
        this.m_labels = new ObjectIntOpenHashMap<String>();
        for (int i = 0; i < this.n_labels; ++i) {
            this.a_labels.add(labels[i]);
            this.m_labels.put(labels[i], i + 1); // label map is 1-based
        }
    }

    /**
     * Reads the feature count, the number of feature types, and then for every type its name,
     * its number of values, and one "value index" line per value.
     */
    @Override
    protected void loadFeatures(BufferedReader fin) throws IOException {
        this.n_features = Integer.parseInt(fin.readLine());
        int typeSize = Integer.parseInt(fin.readLine());
        this.m_features = new HashMap<String, ObjectIntOpenHashMap<String>>();
        for (int i = 0; i < typeSize; ++i) {
            ObjectIntOpenHashMap<String> map = new ObjectIntOpenHashMap<String>();
            String type = fin.readLine();
            int valueSize = Integer.parseInt(fin.readLine());
            for (int j = 0; j < valueSize; ++j) {
                String[] tmp = P_SPACE.split(fin.readLine());
                map.put(tmp[0], Integer.parseInt(tmp[1]));
            }
            this.m_features.put(type, map);
        }
    }

    /**
     * Reads the weight matrix: a header line with the total number of weights, then
     * {@code n_features * n_labels} space-terminated numbers in row-major order, then a line end.
     *
     * @throws IOException if the stream ends before all weights have been read
     */
    @Override
    protected void loadWeightVector(BufferedReader fin) throws Exception {
        StringBuilder token = new StringBuilder();
        this.d_weights = new ArrayList<DoubleArrayList>(this.n_features);
        // The dimensions are already known from the label and feature sections; the header is
        // only parsed to check that it is a number.
        Integer.parseInt(fin.readLine());
        for (int i = 0; i < this.n_features; ++i) {
            if (i % 100000 == 0) {
                System.out.print(".");
            }
            DoubleArrayList weight = new DoubleArrayList(this.n_labels);
            for (int j = 0; j < this.n_labels; ++j) {
                weight.add(Double.parseDouble(this.readWeightToken(fin, token)));
            }
            this.d_weights.add(weight);
        }
        fin.readLine(); // consume the line end written after the last weight
    }

    /**
     * Reads characters up to (and consuming) the next space and returns them as a string.
     *
     * @param fin   source reader
     * @param token scratch buffer reused across calls; cleared on entry
     * @throws IOException if the stream ends before a space is found
     */
    private String readWeightToken(BufferedReader fin, StringBuilder token) throws IOException {
        token.setLength(0);
        int ch;
        while ((ch = fin.read()) != ' ') {
            if (ch == -1) {
                throw new IOException("Unexpected end of stream while reading model weights");
            }
            token.append((char) ch);
        }
        return token.toString();
    }

    /**
     * Writes the solver id, labels, features, and weights in the order read by
     * {@link #load(BufferedReader)}. Any exception is printed to stderr and otherwise ignored.
     */
    @Override
    public void save(PrintStream fout) {
        System.out.println("Saving model:");
        try {
            fout.println(this.i_solver);
            this.saveLabels(fout);
            this.saveFeatures(fout);
            this.saveWeightVector(fout);
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        System.out.println();
    }

    /** Writes the label count and one line of space-joined labels. */
    @Override
    protected void saveLabels(PrintStream fout) {
        fout.println(this.n_labels);
        fout.println(UTArray.join(this.a_labels, " "));
    }

    /**
     * Writes the feature count, the number of feature types, and then for every type its name,
     * its number of values, and one "value index" line per value.
     */
    @Override
    protected void saveFeatures(PrintStream fout) {
        fout.println(this.n_features);
        fout.println(this.m_features.size());
        for (String type : this.m_features.keySet()) {
            ObjectIntOpenHashMap<String> map = this.m_features.get(type);
            fout.println(type);
            fout.println(map.size());
            for (ObjectCursor<String> cur : map.keys()) {
                String value = cur.value;
                fout.println(value + " " + map.get(value));
            }
        }
    }

    /**
     * Writes the total number of weights, then every weight followed by a space in row-major
     * order (no line breaks between rows), then a single line end.
     */
    @Override
    protected void saveWeightVector(PrintStream fout) {
        fout.println(this.n_labels * this.n_features);
        for (int i = 0; i < this.n_features; ++i) {
            if (i % 100000 == 0) {
                System.out.print(".");
            }
            DoubleArrayList weight = this.d_weights.get(i);
            StringBuilder build = new StringBuilder();
            for (int j = 0; j < this.n_labels; ++j) {
                build.append(weight.get(j));
                build.append(' ');
            }
            fout.print(build.toString());
        }
        fout.println();
    }

    /** Registers {@code label} if unseen, growing every weight and accumulator row by one zero. */
    @Override
    public void addLabel(String label) {
        if (!this.m_labels.containsKey(label)) {
            this.a_labels.add(label);
            this.m_labels.put(label, ++this.n_labels);
            this.addLabelAux();
        }
    }

    /** Appends a zero column for the newly added label to every weight and accumulator row. */
    private void addLabelAux() {
        for (int i = 0; i < this.n_features; ++i) {
            this.d_weights.get(i).add(0.0);
            this.d_gs.get(i).add(0.0);
        }
    }

    /** Registers the feature {@code (type, value)} if unseen, giving it the next feature index. */
    @Override
    public void addFeature(String type, String value) {
        ObjectIntOpenHashMap<String> map = this.m_features.get(type);
        if (map == null) {
            map = new ObjectIntOpenHashMap<String>();
            this.m_features.put(type, map);
        }
        if (!map.containsKey(value)) {
            map.put(value, this.n_features++);
            this.addFeatureAux();
        }
    }

    /** Appends a zero weight row and a zero accumulator row for the newly added feature. */
    private void addFeatureAux() {
        this.d_weights.add(this.getBlankDoubleArrayList(this.n_labels));
        this.d_gs.add(this.getBlankDoubleArrayList(this.n_labels));
    }

    /** Registers every feature of {@code vector}. */
    public void addFeatures(StringFeatureVector vector) {
        int size = vector.size();
        for (int i = 0; i < size; ++i) {
            this.addFeature(vector.getType(i), vector.getValue(i));
        }
    }

    /**
     * Computes one score per label. Each score starts at the bias row (row 0). Then, for every
     * feature of {@code x} accepted by {@code isRange}, that feature's weight row is added,
     * scaled by the feature value (1.0 when {@code x} is unweighted).
     *
     * @return a newly allocated array of length {@code n_labels}; callers may modify it
     */
    @Override
    public double[] getScores(SparseFeatureVector x) {
        double[] scores = this.d_weights.get(0).toArray();
        boolean weighted = x.hasWeight();
        int size = x.size();
        for (int i = 0; i < size; ++i) {
            int index = x.getIndex(i);
            if (!this.isRange(index)) {
                continue;
            }
            // Multiplying by 1.0 is exact, so unweighted vectors simply add the raw weights.
            double value = weighted ? x.getWeight(i) : 1.0;
            DoubleArrayList weight = this.d_weights.get(index);
            for (int label = 0; label < this.n_labels; ++label) {
                scores[label] += weight.get(label) * value;
            }
        }
        return scores;
    }

    /** Returns one prediction per label, in label-index order, with the score from getScores. */
    @Override
    public List<StringPrediction> getPredictions(SparseFeatureVector x) {
        ArrayList<StringPrediction> list = new ArrayList<StringPrediction>(this.n_labels);
        double[] scores = this.getScores(x);
        for (int i = 0; i < this.n_labels; ++i) {
            list.add(new StringPrediction(this.a_labels.get(i), scores[i]));
        }
        return list;
    }

    /** Trains on each (label, feature vector) pair in order. */
    public void updateWeights(List<Pair<String, StringFeatureVector>> instances) {
        for (Pair<String, StringFeatureVector> p : instances) {
            this.updateWeights((String) p.o1, (StringFeatureVector) p.o2);
        }
    }

    /**
     * Trains on one instance, registering its label and features first.
     *
     * <p>The gold label's score is lowered by 1 (a margin). If the highest-scoring label is then
     * not the gold label, the weights are updated. Ties go to the lowest label index.
     *
     * <p>NOTE(review): only rows whose indices appear in {@code x} are updated, so the bias row 0
     * stays zero unless {@code toSparseFeatureVector} emits index 0 — confirm this is intended.
     */
    public void updateWeights(String label, StringFeatureVector vector) {
        this.addLabel(label);
        this.addFeatures(vector);
        SparseFeatureVector x = this.toSparseFeatureVector(vector);
        int y = this.getLabelIndex(label);
        double[] scores = this.getScores(x);
        scores[y] -= 1.0;
        IntPrediction max = new IntPrediction(0, scores[0]);
        for (int i = 1; i < this.n_labels; ++i) {
            if (scores[i] > max.score) {
                max.set(i, scores[i]);
            }
        }
        if (max.label != y) {
            this.updateCounts(y, max.label, x);
            this.updateWeights(y, max.label, x);
        }
    }

    /**
     * Adds the squared feature values of {@code x} to the accumulators of labels {@code yp} and
     * {@code yn}. Each feature counts as 1.0 when {@code x} is unweighted.
     */
    private void updateCounts(int yp, int yn, SparseFeatureVector x) {
        boolean weighted = x.hasWeight();
        int len = x.size();
        for (int i = 0; i < len; ++i) {
            double v = weighted ? x.getWeight(i) : 1.0;
            double d = v * v;
            DoubleArrayList g = this.d_gs.get(x.getIndex(i));
            this.add(g, yp, d);
            this.add(g, yn, d);
        }
    }

    /**
     * For every feature of {@code x}, moves the weight of label {@code yp} up and the weight of
     * label {@code yn} down. Each step is the feature value (1.0 when unweighted) times the
     * per-coordinate step size from getUpdate.
     */
    private void updateWeights(int yp, int yn, SparseFeatureVector x) {
        boolean weighted = x.hasWeight();
        int len = x.size();
        for (int i = 0; i < len; ++i) {
            int xi = x.getIndex(i);
            double vi = weighted ? x.getWeight(i) : 1.0;
            DoubleArrayList w = this.d_weights.get(xi);
            this.add(w, yp, vi * this.getUpdate(yp, xi));
            this.add(w, yn, -vi * this.getUpdate(yn, xi));
        }
    }

    /** Adds {@code value} to {@code list[index]} in place. */
    private void add(DoubleArrayList list, int index, double value) {
        list.set(index, list.get(index) + value);
    }

    /**
     * Returns the AdaGrad step size {@code alpha / (rho + sqrt(G))} for label {@code y} of
     * feature {@code x}. {@code G} is that coordinate's accumulated sum of squares.
     */
    private double getUpdate(int y, int x) {
        return this.d_alpha / (this.d_rho + Math.sqrt(this.d_gs.get(x).get(y)));
    }
}

