/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.regression;

import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.math.MutableInnerProductModule;
import breeze.optimize.CachedDiffFunction;
import breeze.optimize.FirstOrderMinimizer;
import breeze.optimize.LBFGS;
import breeze.optimize.StochasticDiffFunction;
import java.io.IOException;
import java.io.Serializable;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.feature.InstanceBlock;
import org.apache.spark.ml.feature.InstanceBlock$;
import org.apache.spark.ml.feature.StandardScalerModel$;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.optim.aggregator.AFTBlockAggregator;
import org.apache.spark.ml.optim.loss.DifferentiableRegularization;
import org.apache.spark.ml.optim.loss.RDDLossFunction;
import org.apache.spark.ml.param.BooleanParam;
import org.apache.spark.ml.param.DoubleArrayParam;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.shared.HasAggregationDepth;
import org.apache.spark.ml.param.shared.HasFitIntercept;
import org.apache.spark.ml.param.shared.HasMaxBlockSizeInMB;
import org.apache.spark.ml.param.shared.HasMaxIter;
import org.apache.spark.ml.param.shared.HasTol;
import org.apache.spark.ml.regression.AFTSurvivalRegression$;
import org.apache.spark.ml.regression.AFTSurvivalRegressionModel;
import org.apache.spark.ml.regression.AFTSurvivalRegressionParams;
import org.apache.spark.ml.regression.Regressor;
import org.apache.spark.ml.stat.Summarizer$;
import org.apache.spark.ml.stat.SummarizerBuffer;
import org.apache.spark.ml.util.DatasetUtils$;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.Instrumentation;
import org.apache.spark.ml.util.Instrumentation$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.mllib.util.MLUtils$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.storage.StorageLevel$;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.ArrayOps$;
import scala.collection.Iterator;
import scala.collection.SeqOps;
import scala.collection.immutable.Seq;
import scala.collection.mutable.ArrayBuilder;
import scala.collection.mutable.ArrayBuilder$;
import scala.math.package$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichDouble$;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;
import scala.runtime.Statics;
import scala.runtime.java8.JFunction1;

@ScalaSignature(bytes="\u0006\u0005\tEb\u0001B\f\u0019\u0001\rB\u0001\"\u0011\u0001\u0003\u0006\u0004%\tE\u0011\u0005\t3\u0002\u0011\t\u0011)A\u0005\u0007\")1\f\u0001C\u00019\")1\f\u0001C\u0001A\")!\r\u0001C\u0001G\")\u0001\u000e\u0001C\u0001S\")1\u000f\u0001C\u0001i\")q\u000f\u0001C\u0001q\")a\u0010\u0001C\u0001\u007f\"9\u00111\u0002\u0001\u0005\u0002\u00055\u0001bBA\n\u0001\u0011\u0005\u0011Q\u0003\u0005\b\u0003?\u0001A\u0011AA\u0011\u0011\u001d\tY\u0003\u0001C)\u0003[Aq!a\u0016\u0001\t\u0013\tI\u0006C\u0004\u0002,\u0002!\t%!,\t\u000f\u0005\u0005\u0007\u0001\"\u0011\u0002D\"A\u0011q\u001b\u0001\u0005Bq\tInB\u0004\u0002pbA\t!!=\u0007\r]A\u0002\u0012AAz\u0011\u0019Y6\u0003\"\u0001\u0003\u0012!9!1C\n\u0005B\tU\u0001\"\u0003B\u000f'\u0005\u0005I\u0011\u0002B\u0010\u0005U\te\tV*veZLg/\u00197SK\u001e\u0014Xm]:j_:T!!\u0007\u000e\u0002\u0015I,wM]3tg&|gN\u0003\u0002\u001c9\u0005\u0011Q\u000e\u001c\u0006\u0003;y\tQa\u001d9be.T!a\b\u0011\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005\t\u0013aA8sO\u000e\u00011#\u0002\u0001%eUZ\u0004#B\u0013'Q9zS\"\u0001\r\n\u0005\u001dB\"!\u0003*fOJ,7o]8s!\tIC&D\u0001+\u0015\tY#$\u0001\u0004mS:\fGnZ\u0005\u0003[)\u0012aAV3di>\u0014\bCA\u0013\u0001!\t)\u0003'\u0003\u000221\tQ\u0012I\u0012+TkJ4\u0018N^1m%\u0016<'/Z:tS>tWj\u001c3fYB\u0011QeM\u0005\u0003ia\u00111$\u0011$U'V\u0014h/\u001b<bYJ+wM]3tg&|g\u000eU1sC6\u001c\bC\u0001\u001c:\u001b\u00059$B\u0001\u001d\u001b\u0003\u0011)H/\u001b7\n\u0005i:$!\u0006#fM\u0006,H\u000e\u001e)be\u0006l7o\u0016:ji\u0006\u0014G.\u001a\t\u0003y}j\u0011!\u0010\u0006\u0003}q\t\u0001\"\u001b8uKJt\u0017\r\\\u0005\u0003\u0001v\u0012q\u0001T8hO&tw-A\u0002vS\u0012,\u0012a\u0011\t\u0003\t6s!!R&\u0011\u0005\u0019KU\"A$\u000b\u0005!\u0013\u0013A\u0002\u001fs_>$hHC\u0001K\u0003\u0015\u00198-\u00197b\u0013\ta\u0015*\u0001\u0004Qe\u0016$WMZ\u0005\u0003\u001d>\u0013aa\u0015;sS:<'B\u0001'JQ\r\t\u0011k\u0016\t\u0003%Vk\u0011a\u0015\u0006\u0003)r\t!\"\u00198o_R\fG/[8o\u0013\t16KA\u0003TS:\u001cW-I\u0001Y\u0003\u0015\tdF\u000e\u00181\u0003\u0011)\u0018\u000e\u001a\u0011)\u0007\t\tv+\u0001\u0004=S:LGO\u0010\u000b\u0003]uCQ!Q\u0002A\u0002\rC3!X)XQ\r\u0019\u0011k\u0016\u000b\u0002]!\u001aA!U,\u0002\u0019M,GoQ3og>\u00148i\u001c7\u0015\u0005\u0011,W\"\u0001\u0001\t\u000b\u0019,\u0001\u0019A\"\u0002\u000bY\fG.^3)\u0007\u0015\tv+\u0001\rtKR\fV/\u00198uS2,\u0007K]8cC\nLG.\u001b;jKN$\"\u0001\u001a6\t\u000b\u00194\u0001\u0019A6\u0011\u00071lw.D\u0001J\u0013\tq\u0017JA\u0003BeJ\f\u0017\u0010\u0005\u0002ma&\u0011\u0011/\u0013\u0002\u0007\t>,(\r\\3)\u0007\u0019\tv+A\btKR\fV/\u00198uS2,7oQ8m)\t!W\u000fC\u0003g\u000f\u0001\u00071\tK\u0002\b#^\u000bqb]3u\r&$\u0018J\u001c;fe\u000e,\u0007\u000f\u001e\u000b\u0003IfDQA\u001a\u0005A\u0002i\u0004\"\u0001\\>\n\u0005qL%a\u0002\"p_2,\u0017M\u001c\u0015\u0004\u0011E;\u0016AC:fi6\u000b\u00070\u0013;feR\u0019A-!\u0001\t\r\u0019L\u0001\u0019AA\u0002!\ra\u0017QA\u0005\u0004\u0003\u000fI%aA%oi\"\u001a\u0011\"U,\u0002\rM,G\u000fV8m)\r!\u0017q\u0002\u0005\u0006M*\u0001\ra\u001c\u0015\u0004\u0015E;\u0016aE:fi\u0006;wM]3hCRLwN\u001c#faRDGc\u00013\u0002\u0018!1am\u0003a\u0001\u0003\u0007ACaC)\u0002\u001c\u0005\u0012\u0011QD\u0001\u0006e9\nd\u0006M\u0001\u0014g\u0016$X*\u0019=CY>\u001c7nU5{K&sWJ\u0011\u000b\u0004I\u0006\r\u0002\"\u00024\r\u0001\u0004y\u0007\u0006\u0002\u0007R\u0003O\t#!!\u000b\u0002\u000bMr\u0013G\f\u0019\u0002\u000bQ\u0014\u0018-\u001b8\u0015\u0007=\ny\u0003C\u0004\u000225\u0001\r!a\r\u0002\u000f\u0011\fG/Y:fiB\"\u0011QGA#!\u0019\t9$!\u0010\u0002B5\u0011\u0011\u0011\b\u0006\u0004\u0003wa\u0012aA:rY&!\u0011qHA\u001d\u0005\u001d!\u0015\r^1tKR\u0004B!a\u0011\u0002F1\u0001A\u0001DA$\u0003_\t\t\u0011!A\u0003\u0002\u0005%#aA0%cE!\u00111JA)!\ra\u0017QJ\u0005\u0004\u0003\u001fJ%a\u0002(pi\"Lgn\u001a\t\u0004Y\u0006M\u0013bAA+\u0013\n\u0019\u0011I\\=\u0002\u0013Q\u0014\u0018-\u001b8J[BdGCDA.\u0003C\ni(!!\u0002\u0006\u0006%\u0015q\u0015\t\u0006Y\u0006u3n[\u0005\u0004\u0003?J%A\u0002+va2,'\u0007C\u0004\u0002d9\u0001\r!!\u001a\u0002\u0013%t7\u000f^1oG\u0016\u001c\bCBA4\u0003[\n\t(\u0004\u0002\u0002j)\u0019\u00111\u000e\u000f\u0002\u0007I$G-\u0003\u0003\u0002p\u0005%$a\u0001*E\tB!\u00111OA=\u001b\t\t)HC\u0002\u0002xi\tqAZ3biV\u0014X-\u0003\u0003\u0002|\u0005U$\u0001C%ogR\fgnY3\t\r\u0005}d\u00021\u0001p\u0003M\t7\r^;bY\ncwnY6TSj,\u0017J\\'C\u0011\u0019\t\u0019I\u0004a\u0001W\u0006Ya-Z1ukJ,7o\u0015;e\u0011\u0019\t9I\u0004a\u0001W\u0006aa-Z1ukJ,7/T3b]\"9\u00111\u0012\bA\u0002\u00055\u0015!C8qi&l\u0017N_3s!\u0019\ty)!'\u0002\u001e6\u0011\u0011\u0011\u0013\u0006\u0005\u0003'\u000b)*\u0001\u0005paRLW.\u001b>f\u0015\t\t9*\u0001\u0004ce\u0016,'0Z\u0005\u0005\u00037\u000b\tJA\u0003M\u0005\u001a;5\u000bE\u0003\u0002 \u0006\rv.\u0004\u0002\u0002\"*\u00191&!&\n\t\u0005\u0015\u0016\u0011\u0015\u0002\f\t\u0016t7/\u001a,fGR|'\u000f\u0003\u0004\u0002*:\u0001\ra[\u0001\u0010S:LG/[1m'>dW\u000f^5p]\u0006yAO]1og\u001a|'/\\*dQ\u0016l\u0017\r\u0006\u0003\u00020\u0006m\u0006\u0003BAY\u0003ok!!a-\u000b\t\u0005U\u0016\u0011H\u0001\u0006if\u0004Xm]\u0005\u0005\u0003s\u000b\u0019L\u0001\u0006TiJ,8\r\u001e+za\u0016Dq!!0\u0010\u0001\u0004\ty+\u0001\u0004tG\",W.\u0019\u0015\u0004\u001fE;\u0016\u0001B2paf$2ALAc\u0011\u001d\t9\r\u0005a\u0001\u0003\u0013\fQ!\u001a=ue\u0006\u0004B!a3\u0002R6\u0011\u0011Q\u001a\u0006\u0004\u0003\u001fT\u0012!\u00029be\u0006l\u0017\u0002BAj\u0003\u001b\u0014\u0001\u0002U1sC6l\u0015\r\u001d\u0015\u0004!E;\u0016!E3ti&l\u0017\r^3N_\u0012,GnU5{KR!\u00111\\Aq!\ra\u0017Q\\\u0005\u0004\u0003?L%\u0001\u0002'p]\u001eDq!!\r\u0012\u0001\u0004\t\u0019\u000f\r\u0003\u0002f\u0006%\bCBA\u001c\u0003{\t9\u000f\u0005\u0003\u0002D\u0005%H\u0001DAv\u0003C\f\t\u0011!A\u0003\u0002\u0005%#aA0%e!\u001a\u0001!U,\u0002+\u00053EkU;sm&4\u0018\r\u001c*fOJ,7o]5p]B\u0011QeE\n\b'\u0005U\u00181 B\u0001!\ra\u0017q_\u0005\u0004\u0003sL%AB!osJ+g\r\u0005\u00037\u0003{t\u0013bAA\u0000o\t)B)\u001a4bk2$\b+\u0019:b[N\u0014V-\u00193bE2,\u0007\u0003\u0002B\u0002\u0005\u001bi!A!\u0002\u000b\t\t\u001d!\u0011B\u0001\u0003S>T!Aa\u0003\u0002\t)\fg/Y\u0005\u0005\u0005\u001f\u0011)A\u0001\u0007TKJL\u0017\r\\5{C\ndW\r\u0006\u0002\u0002r\u0006!An\\1e)\rq#q\u0003\u0005\u0007\u00053)\u0002\u0019A\"\u0002\tA\fG\u000f\u001b\u0015\u0004+E;\u0016\u0001D<sSR,'+\u001a9mC\u000e,GC\u0001B\u0011!\u0011\u0011\u0019C!\u000b\u000e\u0005\t\u0015\"\u0002\u0002B\u0014\u0005\u0013\tA\u0001\\1oO&!!1\u0006B\u0013\u0005\u0019y%M[3di\"\u001a1#U,)\u0007I\tv\u000b")
public class AFTSurvivalRegression
extends Regressor<Vector, AFTSurvivalRegression, AFTSurvivalRegressionModel>
implements AFTSurvivalRegressionParams,
DefaultParamsWritable {
    private final String uid;
    private Param<String> censorCol;
    private DoubleArrayParam quantileProbabilities;
    private Param<String> quantilesCol;
    private DoubleParam maxBlockSizeInMB;
    private IntParam aggregationDepth;
    private BooleanParam fitIntercept;
    private DoubleParam tol;
    private IntParam maxIter;

    public static AFTSurvivalRegression load(String path) {
        return AFTSurvivalRegression$.MODULE$.load(path);
    }

    public static MLReader<AFTSurvivalRegression> read() {
        return AFTSurvivalRegression$.MODULE$.read();
    }

    @Override
    public MLWriter write() {
        return DefaultParamsWritable.write$(this);
    }

    @Override
    public void save(String path) throws IOException {
        MLWritable.save$(this, path);
    }

    @Override
    public String getCensorCol() {
        return AFTSurvivalRegressionParams.getCensorCol$(this);
    }

    @Override
    public double[] getQuantileProbabilities() {
        return AFTSurvivalRegressionParams.getQuantileProbabilities$(this);
    }

    @Override
    public String getQuantilesCol() {
        return AFTSurvivalRegressionParams.getQuantilesCol$(this);
    }

    @Override
    public boolean hasQuantilesCol() {
        return AFTSurvivalRegressionParams.hasQuantilesCol$(this);
    }

    @Override
    public StructType validateAndTransformSchema(StructType schema, boolean fitting) {
        return AFTSurvivalRegressionParams.validateAndTransformSchema$(this, schema, fitting);
    }

    @Override
    public final double getMaxBlockSizeInMB() {
        return HasMaxBlockSizeInMB.getMaxBlockSizeInMB$(this);
    }

    @Override
    public final int getAggregationDepth() {
        return HasAggregationDepth.getAggregationDepth$(this);
    }

    @Override
    public final boolean getFitIntercept() {
        return HasFitIntercept.getFitIntercept$(this);
    }

    @Override
    public final double getTol() {
        return HasTol.getTol$(this);
    }

    @Override
    public final int getMaxIter() {
        return HasMaxIter.getMaxIter$(this);
    }

    @Override
    public final Param<String> censorCol() {
        return this.censorCol;
    }

    @Override
    public final DoubleArrayParam quantileProbabilities() {
        return this.quantileProbabilities;
    }

    @Override
    public final Param<String> quantilesCol() {
        return this.quantilesCol;
    }

    @Override
    public final void org$apache$spark$ml$regression$AFTSurvivalRegressionParams$_setter_$censorCol_$eq(Param<String> x$1) {
        this.censorCol = x$1;
    }

    @Override
    public final void org$apache$spark$ml$regression$AFTSurvivalRegressionParams$_setter_$quantileProbabilities_$eq(DoubleArrayParam x$1) {
        this.quantileProbabilities = x$1;
    }

    @Override
    public final void org$apache$spark$ml$regression$AFTSurvivalRegressionParams$_setter_$quantilesCol_$eq(Param<String> x$1) {
        this.quantilesCol = x$1;
    }

    @Override
    public final DoubleParam maxBlockSizeInMB() {
        return this.maxBlockSizeInMB;
    }

    @Override
    public final void org$apache$spark$ml$param$shared$HasMaxBlockSizeInMB$_setter_$maxBlockSizeInMB_$eq(DoubleParam x$1) {
        this.maxBlockSizeInMB = x$1;
    }

    @Override
    public final IntParam aggregationDepth() {
        return this.aggregationDepth;
    }

    @Override
    public final void org$apache$spark$ml$param$shared$HasAggregationDepth$_setter_$aggregationDepth_$eq(IntParam x$1) {
        this.aggregationDepth = x$1;
    }

    @Override
    public final BooleanParam fitIntercept() {
        return this.fitIntercept;
    }

    @Override
    public final void org$apache$spark$ml$param$shared$HasFitIntercept$_setter_$fitIntercept_$eq(BooleanParam x$1) {
        this.fitIntercept = x$1;
    }

    @Override
    public final DoubleParam tol() {
        return this.tol;
    }

    @Override
    public final void org$apache$spark$ml$param$shared$HasTol$_setter_$tol_$eq(DoubleParam x$1) {
        this.tol = x$1;
    }

    @Override
    public final IntParam maxIter() {
        return this.maxIter;
    }

    @Override
    public final void org$apache$spark$ml$param$shared$HasMaxIter$_setter_$maxIter_$eq(IntParam x$1) {
        this.maxIter = x$1;
    }

    @Override
    public String uid() {
        return this.uid;
    }

    public AFTSurvivalRegression setCensorCol(String value) {
        return (AFTSurvivalRegression)this.set(this.censorCol(), value);
    }

    public AFTSurvivalRegression setQuantileProbabilities(double[] value) {
        return (AFTSurvivalRegression)this.set(this.quantileProbabilities(), value);
    }

    public AFTSurvivalRegression setQuantilesCol(String value) {
        return (AFTSurvivalRegression)this.set(this.quantilesCol(), value);
    }

    public AFTSurvivalRegression setFitIntercept(boolean value) {
        return (AFTSurvivalRegression)this.set(this.fitIntercept(), BoxesRunTime.boxToBoolean((boolean)value));
    }

    public AFTSurvivalRegression setMaxIter(int value) {
        return (AFTSurvivalRegression)this.set(this.maxIter(), BoxesRunTime.boxToInteger((int)value));
    }

    public AFTSurvivalRegression setTol(double value) {
        return (AFTSurvivalRegression)this.set(this.tol(), BoxesRunTime.boxToDouble((double)value));
    }

    public AFTSurvivalRegression setAggregationDepth(int value) {
        return (AFTSurvivalRegression)this.set(this.aggregationDepth(), BoxesRunTime.boxToInteger((int)value));
    }

    public AFTSurvivalRegression setMaxBlockSizeInMB(double value) {
        return (AFTSurvivalRegression)this.set(this.maxBlockSizeInMB(), BoxesRunTime.boxToDouble((double)value));
    }

    @Override
    public AFTSurvivalRegressionModel train(Dataset<?> dataset) {
        return (AFTSurvivalRegressionModel)Instrumentation$.MODULE$.instrumented((Function1 & Serializable)instr -> {
            double[] initialSolution;
            LBFGS optimizer;
            Tuple2<double[], double[]> tuple2;
            instr.logPipelineStage(this);
            instr.logDataset(dataset);
            instr.logParams(this, (Seq<Param<?>>)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new Param[]{this.labelCol(), this.featuresCol(), this.censorCol(), this.predictionCol(), this.quantilesCol(), this.fitIntercept(), this.maxIter(), this.tol(), this.aggregationDepth(), this.maxBlockSizeInMB()}));
            instr.logNamedValue("quantileProbabilities.size", this.$(this.quantileProbabilities()).length);
            StorageLevel storageLevel = dataset.storageLevel();
            StorageLevel storageLevel2 = StorageLevel$.MODULE$.NONE();
            if (storageLevel == null ? storageLevel2 != null : !storageLevel.equals(storageLevel2)) {
                instr.logWarning((Function0<String>)(Function0 & Serializable)() -> "Input instances will be standardized, blockified to blocks, and then cached during training. Be careful of double caching!");
            }
            Column casted = functions$.MODULE$.col(this.$(this.censorCol())).cast((DataType)DoubleType$.MODULE$);
            Column validatedCensorCol = functions$.MODULE$.when(casted.isNull().$bar$bar((Object)casted.isNaN()), (Object)functions$.MODULE$.raise_error(functions$.MODULE$.lit((Object)"Censors MUST NOT be Null or NaN"))).when(casted.$eq$bang$eq((Object)BoxesRunTime.boxToInteger((int)0)).$amp$amp((Object)casted.$eq$bang$eq((Object)BoxesRunTime.boxToInteger((int)1))), (Object)functions$.MODULE$.raise_error(functions$.MODULE$.concat((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.lit((Object)"Censors MUST be in {0, 1}, but got "), casted})))).otherwise((Object)casted);
            RDD instances = dataset.select((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new Column[]{DatasetUtils$.MODULE$.checkRegressionLabels(this.$(this.labelCol())), validatedCensorCol, DatasetUtils$.MODULE$.checkNonNanVectors(this.$(this.featuresCol()))})).rdd().map((Function1 & Serializable)x0$1 -> {
                Some some;
                Row row = x0$1;
                if (row != null && !(some = Row$.MODULE$.unapplySeq(row)).isEmpty() && some.get() != null && ((SeqOps)some.get()).lengthCompare(3) == 0) {
                    Object l = ((SeqOps)some.get()).apply(0);
                    Object c = ((SeqOps)some.get()).apply(1);
                    Object v = ((SeqOps)some.get()).apply(2);
                    if (l instanceof Double) {
                        double d = BoxesRunTime.unboxToDouble((Object)l);
                        if (c instanceof Double) {
                            double d2 = BoxesRunTime.unboxToDouble((Object)c);
                            if (v instanceof Vector) {
                                Vector vector = (Vector)v;
                                return new Instance(d, d2, vector);
                            }
                        }
                    }
                }
                throw new MatchError((Object)row);
            }, ClassTag$.MODULE$.apply(Instance.class)).setName("training instances");
            SummarizerBuffer summarizer = (SummarizerBuffer)instances.treeAggregate((Object)Summarizer$.MODULE$.createSummarizerBuffer((Seq<String>)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new String[]{"mean", "std", "count"})), (Function2 & Serializable)(c, i) -> c.add(i.features()), (Function2 & Serializable)(c1, c2) -> c1.merge((SummarizerBuffer)c2), BoxesRunTime.unboxToInt((Object)this.$(this.aggregationDepth())), ClassTag$.MODULE$.apply(SummarizerBuffer.class));
            double[] featuresMean = summarizer.mean().toArray();
            double[] featuresStd = summarizer.std().toArray();
            int numFeatures = featuresStd.length;
            instr.logNumFeatures(numFeatures);
            instr.logNumExamples(summarizer.count());
            double actualBlockSizeInMB = BoxesRunTime.unboxToDouble((Object)this.$(this.maxBlockSizeInMB()));
            if (actualBlockSizeInMB == 0.0) {
                actualBlockSizeInMB = InstanceBlock$.MODULE$.DefaultBlockSizeInMB();
                Predef$.MODULE$.require(actualBlockSizeInMB > 0.0, (Function0 & Serializable)() -> "inferred actual BlockSizeInMB must > 0");
                instr.logNamedValue("actualBlockSizeInMB", Double.toString(actualBlockSizeInMB));
            }
            if (!BoxesRunTime.unboxToBoolean((Object)this.$(this.fitIntercept())) && RichInt$.MODULE$.until$extension(Predef$.MODULE$.intWrapper(0), numFeatures).exists((Function1)(JFunction1.mcZI.sp & Serializable)i -> featuresStd[i] == 0.0 && summarizer.mean().apply(i) != 0.0)) {
                instr.logWarning((Function0<String>)(Function0 & Serializable)() -> "Fitting AFTSurvivalRegressionModel without intercept on dataset with constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero columns. This behavior is different from R survival::survreg.");
            }
            if ((tuple2 = this.trainImpl((RDD<Instance>)instances, actualBlockSizeInMB, featuresStd, featuresMean, (LBFGS<DenseVector<Object>>)(optimizer = new LBFGS(BoxesRunTime.unboxToInt((Object)this.$(this.maxIter())), 10, BoxesRunTime.unboxToDouble((Object)this.$(this.tol())), (MutableInnerProductModule)DenseVector$.MODULE$.space_Double())), initialSolution = (double[])Array$.MODULE$.ofDim(numFeatures + 2, (ClassTag)ClassTag$.MODULE$.Double()))) == null) {
                throw new MatchError(tuple2);
            }
            double[] rawCoefficients = (double[])tuple2._1();
            double[] objectiveHistory = (double[])tuple2._2();
            Tuple2 tuple22 = new Tuple2((Object)rawCoefficients, (Object)objectiveHistory);
            double[] rawCoefficients2 = (double[])tuple22._1();
            double[] objectiveHistory2 = (double[])tuple22._2();
            if (rawCoefficients2 == null) {
                MLUtils$.MODULE$.optimizerFailed((Instrumentation)instr, optimizer.getClass());
            }
            double[] coefficientArray = (double[])Array$.MODULE$.tabulate(numFeatures, (Function1)(JFunction1.mcDI.sp & Serializable)i -> {
                if (featuresStd[i] != 0.0) {
                    return rawCoefficients2[i] / featuresStd[i];
                }
                return 0.0;
            }, (ClassTag)ClassTag$.MODULE$.Double());
            Vector coefficients = Vectors$.MODULE$.dense(coefficientArray);
            double intercept = rawCoefficients2[numFeatures];
            double scale = package$.MODULE$.exp(rawCoefficients2[numFeatures + 1]);
            return new AFTSurvivalRegressionModel(this.uid(), coefficients, intercept, scale);
        });
    }

    private Tuple2<double[], double[]> trainImpl(RDD<Instance> instances, double actualBlockSizeInMB, double[] featuresStd, double[] featuresMean, LBFGS<DenseVector<Object>> optimizer, double[] initialSolution) {
        double[] solution;
        int numFeatures = featuresStd.length;
        double[] inverseStd = (double[])ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.doubleArrayOps(featuresStd), (Function1)(JFunction1.mcDD.sp & Serializable)std -> {
            if (std != 0.0) {
                return 1.0 / std;
            }
            return 0.0;
        }, (ClassTag)ClassTag$.MODULE$.Double());
        double[] scaledMean = (double[])Array$.MODULE$.tabulate(numFeatures, (Function1)(JFunction1.mcDI.sp & Serializable)i -> inverseStd[i] * featuresMean[i], (ClassTag)ClassTag$.MODULE$.Double());
        Broadcast bcInverseStd = instances.context().broadcast((Object)inverseStd, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Double.TYPE)));
        Broadcast bcScaledMean = instances.context().broadcast((Object)scaledMean, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Double.TYPE)));
        RDD scaled = instances.mapPartitions((Function1 & Serializable)iter -> {
            Function1<Vector, Vector> func = StandardScalerModel$.MODULE$.getTransformFunc((double[])Array$.MODULE$.empty((ClassTag)ClassTag$.MODULE$.Double()), (double[])bcInverseStd.value(), false, true);
            return iter.map((Function1 & Serializable)x0$1 -> {
                Instance instance = x0$1;
                if (instance != null) {
                    double label = instance.label();
                    double weight = instance.weight();
                    Vector vec = instance.features();
                    return new Instance(label, weight, (Vector)func.apply((Object)vec));
                }
                throw new MatchError((Object)instance);
            });
        }, instances.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Instance.class));
        long maxMemUsage = (long)RichDouble$.MODULE$.ceil$extension(Predef$.MODULE$.doubleWrapper(actualBlockSizeInMB * (double)1024L * (double)1024L));
        RDD blocks = InstanceBlock$.MODULE$.blokifyWithMaxMemUsage((RDD<Instance>)scaled, maxMemUsage).persist(StorageLevel$.MODULE$.MEMORY_AND_DISK()).setName("training blocks (blockSizeInMB=" + actualBlockSizeInMB + ")");
        Function1 & Serializable getAggregatorFunc = (Function1 & Serializable)x$2 -> new AFTBlockAggregator((Broadcast<double[]>)bcScaledMean, BoxesRunTime.unboxToBoolean((Object)this.$(this.fitIntercept())), (Broadcast<Vector>)x$2);
        RDDLossFunction costFun = new RDDLossFunction(blocks, getAggregatorFunc, (Option<DifferentiableRegularization<Vector>>)None$.MODULE$, BoxesRunTime.unboxToInt((Object)this.$(this.aggregationDepth())), ClassTag$.MODULE$.apply(InstanceBlock.class), ClassTag$.MODULE$.apply(AFTBlockAggregator.class));
        if (BoxesRunTime.unboxToBoolean((Object)this.$(this.fitIntercept()))) {
            double adapt = BLAS$.MODULE$.javaBLAS().ddot(numFeatures, initialSolution, 1, scaledMean, 1);
            initialSolution[numFeatures] = initialSolution[numFeatures] + adapt;
        }
        Iterator states = optimizer.iterations((StochasticDiffFunction)new CachedDiffFunction(costFun, DenseVector$.MODULE$.canCopyDenseVector((ClassTag)ClassTag$.MODULE$.Double())), (Object)new DenseVector.mcD.sp(initialSolution));
        ArrayBuilder arrayBuilder = ArrayBuilder$.MODULE$.make((ClassTag)ClassTag$.MODULE$.Double());
        FirstOrderMinimizer.State state = null;
        while (states.hasNext()) {
            state = (FirstOrderMinimizer.State)states.next();
            arrayBuilder.$plus$eq((Object)BoxesRunTime.boxToDouble((double)state.adjustedValue()));
        }
        blocks.unpersist(blocks.unpersist$default$1());
        bcInverseStd.destroy();
        bcScaledMean.destroy();
        double[] dArray = solution = state == null ? null : ((DenseVector)state.x()).toArray$mcD$sp((ClassTag)ClassTag$.MODULE$.Double());
        if (BoxesRunTime.unboxToBoolean((Object)this.$(this.fitIntercept())) && solution != null) {
            double adapt = BLAS$.MODULE$.getBLAS(numFeatures).ddot(numFeatures, solution, 1, scaledMean, 1);
            solution[numFeatures] = solution[numFeatures] - adapt;
        }
        return new Tuple2((Object)solution, arrayBuilder.result());
    }

    @Override
    public StructType transformSchema(StructType schema) {
        return this.validateAndTransformSchema(schema, true);
    }

    @Override
    public AFTSurvivalRegression copy(ParamMap extra) {
        return (AFTSurvivalRegression)this.defaultCopy(extra);
    }

    @Override
    public long estimateModelSize(Dataset<?> dataset) {
        int numFeatures = DatasetUtils$.MODULE$.getNumFeatures(dataset, this.$(this.featuresCol()));
        long size = this.estimateMatadataSize();
        return size += Vectors$.MODULE$.getDenseSize((long)numFeatures);
    }

    public AFTSurvivalRegression(String uid) {
        this.uid = uid;
        HasMaxIter.$init$(this);
        HasTol.$init$(this);
        HasFitIntercept.$init$(this);
        HasAggregationDepth.$init$(this);
        HasMaxBlockSizeInMB.$init$(this);
        AFTSurvivalRegressionParams.$init$(this);
        MLWritable.$init$(this);
        DefaultParamsWritable.$init$(this);
        Statics.releaseFence();
    }

    public AFTSurvivalRegression() {
        this(Identifiable$.MODULE$.randomUID("aftSurvReg"));
    }
}

