/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ctakes.temporal.eval;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Ordering;
import com.lexicalscope.jewel.cli.CliFactory;
import com.lexicalscope.jewel.cli.Option;
import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import org.apache.ctakes.temporal.ae.BackwardsTimeAnnotator;
import org.apache.ctakes.temporal.ae.CRFTimeAnnotator;
import org.apache.ctakes.temporal.ae.ConstituencyBasedTimeAnnotator;
import org.apache.ctakes.temporal.ae.MetaTimeAnnotator;
import org.apache.ctakes.temporal.ae.TimeAnnotator;
import org.apache.ctakes.temporal.ae.feature.selection.FeatureSelection;
import org.apache.ctakes.temporal.eval.EvaluationOfAnnotationSpans_ImplBase;
import org.apache.ctakes.temporal.eval.Evaluation_ImplBase;
import org.apache.ctakes.temporal.eval.I2B2Data;
import org.apache.ctakes.temporal.eval.THYMEData;
import org.apache.ctakes.typesystem.type.textsem.TimeMention;
import org.apache.ctakes.typesystem.type.textspan.Segment;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.collection.CollectionReader;
import org.apache.uima.fit.component.JCasAnnotator_ImplBase;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.tcas.Annotation;
import org.apache.uima.resource.ResourceInitializationException;
import org.cleartk.eval.AnnotationStatistics;
import org.cleartk.ml.CleartkAnnotator;
import org.cleartk.ml.CleartkSequenceAnnotator;
import org.cleartk.ml.Instance;
import org.cleartk.ml.crfsuite.CrfSuiteStringOutcomeDataWriter;
import org.cleartk.ml.feature.transform.InstanceDataWriter;
import org.cleartk.ml.feature.transform.InstanceStream;
import org.cleartk.ml.jar.JarClassifierBuilder;
import org.cleartk.ml.liblinear.LibLinearStringOutcomeDataWriter;

/**
 * Trains and evaluates {@link TimeMention} span annotators.
 * <p>
 * For each selected annotator class, a model is trained on the training patient sets and evaluated
 * on either the dev sets or, when {@link Options#getTest()} is true, on the test sets (the dev sets
 * are then added to training). Per-annotator {@link AnnotationStatistics} are printed at the end.
 */
public class EvaluationOfTimeSpans
extends EvaluationOfAnnotationSpans_ImplBase {
    // Configuration-parameter names passed to uimaFIT when building annotator descriptions.
    // The values are the exact strings the original code used inline.
    private static final String PARAM_IS_TRAINING = "isTraining";
    private static final String PARAM_DATA_WRITER_CLASS_NAME = "dataWriterClassName";
    private static final String PARAM_OUTPUT_DIRECTORY = "outputDirectory";
    private static final String PARAM_CLASSIFIER_JAR_PATH = "classifierJarPath";

    /** The annotator implementation being trained and evaluated. */
    private final Class<? extends JCasAnnotator_ImplBase> annotatorClass;
    /** Arguments forwarded to {@link JarClassifierBuilder#trainAndPackage} when building the model. */
    private final String[] trainingArguments;
    /** When greater than zero, {@link TimeAnnotator} training goes through a feature-selection pass. */
    private final float featureSelectionThreshold;
    /** Value of the SMOTENeighborNumber option, forwarded to {@link TimeAnnotator}'s data writer description. */
    private final float smoteNeighborNumber;
    /** When true, {@link #train} does nothing. */
    private boolean skipTrain = false;

    /**
     * Command-line entry point: evaluates the annotators selected by the -b/-f/-p/-c flags (all four
     * when none is given) and prints their statistics to stderr in ascending order of F1.
     *
     * @param args command-line arguments parsed into {@link Options}
     * @throws Exception if argument parsing, training, or testing fails
     */
    public static void main(String[] args) throws Exception {
        Options options = CliFactory.parseArguments(Options.class, args);
        List<Integer> patientSets = options.getPatients().getList();

        List<Integer> trainItems;
        List<Integer> devItems;
        List<Integer> testItems;
        if (options.getXMLFormat() == Evaluation_ImplBase.XMLFormat.I2B2) {
            // i2b2 splits are derived from the XML directory.
            trainItems = I2B2Data.getTrainPatientSets(options.getXMLDirectory());
            devItems = I2B2Data.getDevPatientSets(options.getXMLDirectory());
            testItems = I2B2Data.getTestPatientSets(options.getXMLDirectory());
        } else {
            // THYME splits are selected from the patient sets by the configured remainders.
            trainItems = THYMEData.getPatientSets(patientSets, options.getTrainRemainders().getList());
            devItems = THYMEData.getPatientSets(patientSets, options.getDevRemainders().getList());
            testItems = THYMEData.getPatientSets(patientSets, options.getTestRemainders().getList());
        }

        List<Integer> allTrain = new ArrayList<Integer>(trainItems);
        List<Integer> allTest;
        if (options.getTest()) {
            // Final evaluation: train on train+dev, test on the test sets.
            allTrain.addAll(devItems);
            allTest = new ArrayList<Integer>(testItems);
        } else {
            allTest = new ArrayList<Integer>(devItems);
        }

        List<Class<? extends JCasAnnotator_ImplBase>> annotatorClasses = Lists.newArrayList();
        if (options.getRunBackwards()) {
            annotatorClasses.add(BackwardsTimeAnnotator.class);
        }
        if (options.getRunForwards()) {
            annotatorClasses.add(TimeAnnotator.class);
        }
        if (options.getRunParserBased()) {
            annotatorClasses.add(ConstituencyBasedTimeAnnotator.class);
        }
        if (options.getRunCrfBased()) {
            annotatorClasses.add(CRFTimeAnnotator.class);
        }
        if (annotatorClasses.isEmpty()) {
            // No annotator flag given: evaluate all of them.
            annotatorClasses.add(BackwardsTimeAnnotator.class);
            annotatorClasses.add(TimeAnnotator.class);
            annotatorClasses.add(ConstituencyBasedTimeAnnotator.class);
            annotatorClasses.add(CRFTimeAnnotator.class);
        }

        // Classifier-trainer arguments per annotator, passed through to JarClassifierBuilder.
        Map<Class<? extends JCasAnnotator_ImplBase>, String[]> annotatorTrainingArguments = Maps.newHashMap();
        annotatorTrainingArguments.put(BackwardsTimeAnnotator.class, new String[]{"-c", "0.1"});
        annotatorTrainingArguments.put(TimeAnnotator.class, new String[]{"-c", "0.1"});
        annotatorTrainingArguments.put(ConstituencyBasedTimeAnnotator.class, new String[]{"-c", "0.3"});
        annotatorTrainingArguments.put(CRFTimeAnnotator.class, new String[]{"-p", "c2=0.3"});

        final Map<Class<? extends JCasAnnotator_ImplBase>, AnnotationStatistics<?>> annotatorStats = Maps.newHashMap();
        for (Class<? extends JCasAnnotator_ImplBase> annotatorClass : annotatorClasses) {
            EvaluationOfTimeSpans evaluation = new EvaluationOfTimeSpans(
                    new File("target/eval/time-spans"),
                    options.getRawTextDirectory(),
                    options.getXMLDirectory(),
                    options.getXMLFormat(),
                    options.getSubcorpus(),
                    options.getXMIDirectory(),
                    options.getTreebankDirectory(),
                    options.getFeatureSelectionThreshold(),
                    options.getSMOTENeighborNumber(),
                    annotatorClass,
                    options.getPrintOverlappingSpans(),
                    annotatorTrainingArguments.get(annotatorClass));
            evaluation.prepareXMIsFor(patientSets);
            evaluation.setSkipTrain(options.getSkipTrain());
            evaluation.printErrors = options.getPrintErrors();
            if (options.getI2B2Output() != null) {
                evaluation.setI2B2Output(options.getI2B2Output() + "/" + annotatorClass.getSimpleName());
            }
            String name = String.format("%s.errors", annotatorClass.getSimpleName());
            evaluation.setLogging(Level.FINE, new File("target/eval", name));
            AnnotationStatistics<?> stats = evaluation.trainAndTest(allTrain, allTest);
            annotatorStats.put(annotatorClass, stats);
        }

        // Ascending by F1, so the best-scoring annotator is printed last.
        Ordering<Class<? extends JCasAnnotator_ImplBase>> byF1 = Ordering.<Double>natural().onResultOf(
                new Function<Class<? extends JCasAnnotator_ImplBase>, Double>() {
                    @Override
                    public Double apply(Class<? extends JCasAnnotator_ImplBase> candidate) {
                        return annotatorStats.get(candidate).f1();
                    }
                });
        for (Class<? extends JCasAnnotator_ImplBase> annotatorClass : byF1.sortedCopy(annotatorClasses)) {
            System.err.printf("===== %s =====\n", annotatorClass.getSimpleName());
            System.err.println(annotatorStats.get(annotatorClass));
        }
    }

    /**
     * Creates an evaluation for a single time-span annotator.
     *
     * @param featureSelectionThreshold values greater than zero enable feature selection for {@link TimeAnnotator}
     * @param numOfSmoteNeighbors forwarded to {@link TimeAnnotator#createDataWriterDescription}
     * @param annotatorClass the annotator implementation to train and test
     * @param printOverlapping stored in the inherited {@code printOverlapping} field
     * @param trainingArguments arguments passed to {@link JarClassifierBuilder#trainAndPackage}
     */
    public EvaluationOfTimeSpans(File baseDirectory, File rawTextDirectory, File xmlDirectory, Evaluation_ImplBase.XMLFormat xmlFormat, Evaluation_ImplBase.Subcorpus subcorpus, File xmiDirectory, File treebankDirectory, float featureSelectionThreshold, float numOfSmoteNeighbors, Class<? extends JCasAnnotator_ImplBase> annotatorClass, boolean printOverlapping, String[] trainingArguments) {
        super(baseDirectory, rawTextDirectory, xmlDirectory, xmlFormat, subcorpus, xmiDirectory, treebankDirectory, TimeMention.class);
        this.annotatorClass = annotatorClass;
        this.featureSelectionThreshold = featureSelectionThreshold;
        this.trainingArguments = trainingArguments;
        this.printOverlapping = printOverlapping;
        this.smoteNeighborNumber = numOfSmoteNeighbors;
    }

    /** Sets whether {@link #train} should be skipped. */
    public void setSkipTrain(boolean val) {
        this.skipTrain = val;
    }

    /** Delegates to the superclass training unless training was disabled via {@link #setSkipTrain}. */
    @Override
    public void train(CollectionReader reader, File directory) throws Exception {
        if (!this.skipTrain) {
            super.train(reader, directory);
        }
    }

    /**
     * Builds the training-time description that writes classifier training data.
     *
     * @throws ResourceInitializationException if the annotator class is neither a
     *     {@link MetaTimeAnnotator}, a {@link CleartkAnnotator}, nor a {@link CleartkSequenceAnnotator}
     */
    @Override
    protected AnalysisEngineDescription getDataWriterDescription(File directory) throws ResourceInitializationException {
        if (MetaTimeAnnotator.class.isAssignableFrom(this.annotatorClass)) {
            // Note: receives the base directory, not the per-annotator model subdirectory.
            return MetaTimeAnnotator.getDataWriterDescription(CrfSuiteStringOutcomeDataWriter.class, directory);
        }
        File modelDirectory = this.getModelDirectory(directory);
        if (CleartkAnnotator.class.isAssignableFrom(this.annotatorClass)) {
            if (this.isTimeAnnotator()) {
                // With feature selection, raw instances are written so trainAndPackage can fit the
                // selector before producing LibLinear training data. Raw Class is kept because the
                // two candidate writer classes share no common parameterized supertype literal.
                @SuppressWarnings("rawtypes")
                Class dataWriterClass = this.featureSelectionThreshold > 0.0f
                        ? InstanceDataWriter.class
                        : LibLinearStringOutcomeDataWriter.class;
                return TimeAnnotator.createDataWriterDescription(dataWriterClass, modelDirectory, this.featureSelectionThreshold, this.smoteNeighborNumber);
            }
            return AnalysisEngineFactory.createEngineDescription(this.annotatorClass,
                    PARAM_IS_TRAINING, true,
                    PARAM_DATA_WRITER_CLASS_NAME, LibLinearStringOutcomeDataWriter.class,
                    PARAM_OUTPUT_DIRECTORY, modelDirectory);
        }
        if (CleartkSequenceAnnotator.class.isAssignableFrom(this.annotatorClass)) {
            return AnalysisEngineFactory.createEngineDescription(this.annotatorClass,
                    PARAM_IS_TRAINING, true,
                    PARAM_DATA_WRITER_CLASS_NAME, CrfSuiteStringOutcomeDataWriter.class,
                    PARAM_OUTPUT_DIRECTORY, modelDirectory);
        }
        throw new ResourceInitializationException("Annotator class was not recognized as an acceptable class!", new Object[0]);
    }

    /**
     * Trains the classifier in the annotator's model directory and packages it as a jar.
     * For {@link TimeAnnotator} with feature selection enabled, first fits and saves the feature
     * selector, then rewrites the transformed instances as LibLinear training data.
     */
    @Override
    protected void trainAndPackage(File directory) throws Exception {
        File modelDirectory = this.getModelDirectory(directory);
        if (this.featureSelectionThreshold > 0.0f && this.isTimeAnnotator()) {
            // The instances are iterated twice: once to fit the selector, once to rewrite them.
            Iterable<Instance<String>> instances = InstanceStream.loadFromDirectory(modelDirectory);
            FeatureSelection<String> featureSelection = TimeAnnotator.createFeatureSelection(this.featureSelectionThreshold);
            featureSelection.train(instances);
            featureSelection.save(TimeAnnotator.createFeatureSelectionURI(modelDirectory));
            LibLinearStringOutcomeDataWriter dataWriter = new LibLinearStringOutcomeDataWriter(modelDirectory);
            for (Instance<String> instance : instances) {
                dataWriter.write(featureSelection.transform(instance));
            }
            dataWriter.finish();
        }
        JarClassifierBuilder.trainAndPackage(modelDirectory, this.trainingArguments);
    }

    /** Builds the test-time description that loads the trained model from the model directory. */
    @Override
    protected AnalysisEngineDescription getAnnotatorDescription(File directory) throws ResourceInitializationException {
        if (MetaTimeAnnotator.class.isAssignableFrom(this.annotatorClass)) {
            return MetaTimeAnnotator.getAnnotatorDescription(directory);
        }
        File modelDirectory = this.getModelDirectory(directory);
        if (this.isTimeAnnotator()) {
            return TimeAnnotator.createAnnotatorDescription(modelDirectory);
        }
        return AnalysisEngineFactory.createEngineDescription(this.annotatorClass,
                PARAM_IS_TRAINING, false,
                PARAM_CLASSIFIER_JAR_PATH, new File(modelDirectory, "model.jar"));
    }

    /**
     * Selects the gold {@link TimeMention}s within {@code segment} via the inherited {@code selectExact}.
     * NOTE(review): identical to {@link #getSystemAnnotations}; presumably the caller supplies the
     * gold vs. system view — confirm against the base class.
     */
    @Override
    protected Collection<? extends Annotation> getGoldAnnotations(JCas jCas, Segment segment) {
        return selectExact(jCas, TimeMention.class, segment);
    }

    /** Selects the system {@link TimeMention}s within {@code segment} via the inherited {@code selectExact}. */
    @Override
    protected Collection<? extends Annotation> getSystemAnnotations(JCas jCas, Segment segment) {
        return selectExact(jCas, TimeMention.class, segment);
    }

    /** Returns the per-annotator model directory: a subdirectory named after the annotator class. */
    private File getModelDirectory(File directory) {
        return new File(directory, this.annotatorClass.getSimpleName());
    }

    /** Returns whether the annotator is exactly {@link TimeAnnotator} (subclasses do not match). */
    private boolean isTimeAnnotator() {
        return TimeAnnotator.class.equals(this.annotatorClass);
    }

    /** Command-line options for this evaluation, on top of the shared evaluation options. */
    interface Options
    extends Evaluation_ImplBase.Options {
        @Option(longName={"featureSelectionThreshold"}, defaultValue={"1"})
        float getFeatureSelectionThreshold();

        @Option(longName={"SMOTENeighborNumber"}, defaultValue={"0"})
        float getSMOTENeighborNumber();

        @Option(shortName={"b"})
        boolean getRunBackwards();

        @Option(shortName={"f"})
        boolean getRunForwards();

        @Option(shortName={"p"})
        boolean getRunParserBased();

        @Option(shortName={"c"})
        boolean getRunCrfBased();

        @Override
        @Option
        boolean getSkipTrain();
    }
}

