/*
 * Decompiled with CFR 0.152.
 */
package org.hibernate.query.sqm.mutation.internal.inline;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.hibernate.engine.jdbc.spi.JdbcServices;
import org.hibernate.engine.spi.SessionFactoryImplementor;
import org.hibernate.internal.util.MutableObject;
import org.hibernate.internal.util.NullnessUtil;
import org.hibernate.internal.util.collections.CollectionHelper;
import org.hibernate.metamodel.mapping.BasicEntityIdentifierMapping;
import org.hibernate.metamodel.mapping.BasicValuedModelPart;
import org.hibernate.metamodel.mapping.EntityIdentifierMapping;
import org.hibernate.metamodel.mapping.MappingModelExpressible;
import org.hibernate.metamodel.mapping.SelectableConsumer;
import org.hibernate.persister.entity.EntityPersister;
import org.hibernate.query.SemanticException;
import org.hibernate.query.spi.DomainQueryExecutionContext;
import org.hibernate.query.spi.QueryParameterImplementor;
import org.hibernate.query.sqm.ComparisonOperator;
import org.hibernate.query.sqm.internal.DomainParameterXref;
import org.hibernate.query.sqm.internal.SqmJdbcExecutionContextAdapter;
import org.hibernate.query.sqm.internal.SqmUtil;
import org.hibernate.query.sqm.mutation.internal.MatchingIdSelectionHelper;
import org.hibernate.query.sqm.mutation.internal.UpdateHandler;
import org.hibernate.query.sqm.mutation.internal.inline.AbstractInlineHandler;
import org.hibernate.query.sqm.mutation.internal.inline.MatchingIdRestrictionProducer;
import org.hibernate.query.sqm.spi.SqmParameterMappingModelResolutionAccess;
import org.hibernate.query.sqm.sql.SqmTranslation;
import org.hibernate.query.sqm.sql.SqmTranslator;
import org.hibernate.query.sqm.tree.expression.SqmParameter;
import org.hibernate.query.sqm.tree.update.SqmUpdateStatement;
import org.hibernate.spi.NavigablePath;
import org.hibernate.sql.ast.SqlAstJoinType;
import org.hibernate.sql.ast.spi.SqlAliasBaseImpl;
import org.hibernate.sql.ast.spi.SqlSelection;
import org.hibernate.sql.ast.tree.MutationStatement;
import org.hibernate.sql.ast.tree.expression.ColumnReference;
import org.hibernate.sql.ast.tree.expression.Expression;
import org.hibernate.sql.ast.tree.expression.SqlTuple;
import org.hibernate.sql.ast.tree.from.NamedTableReference;
import org.hibernate.sql.ast.tree.from.TableGroup;
import org.hibernate.sql.ast.tree.from.TableGroupJoin;
import org.hibernate.sql.ast.tree.from.TableGroupProducer;
import org.hibernate.sql.ast.tree.from.TableReference;
import org.hibernate.sql.ast.tree.from.TableReferenceJoin;
import org.hibernate.sql.ast.tree.from.UnionTableReference;
import org.hibernate.sql.ast.tree.from.ValuesTableGroup;
import org.hibernate.sql.ast.tree.insert.InsertSelectStatement;
import org.hibernate.sql.ast.tree.insert.Values;
import org.hibernate.sql.ast.tree.predicate.ComparisonPredicate;
import org.hibernate.sql.ast.tree.predicate.NullnessPredicate;
import org.hibernate.sql.ast.tree.predicate.Predicate;
import org.hibernate.sql.ast.tree.select.QuerySpec;
import org.hibernate.sql.ast.tree.select.SortSpecification;
import org.hibernate.sql.ast.tree.update.Assignment;
import org.hibernate.sql.ast.tree.update.UpdateStatement;
import org.hibernate.sql.exec.spi.ExecutionContext;
import org.hibernate.sql.exec.spi.JdbcOperationQueryMutation;
import org.hibernate.sql.exec.spi.JdbcParameterBindings;
import org.hibernate.sql.exec.spi.JdbcParametersList;
import org.hibernate.sql.results.internal.SqlSelectionImpl;

/**
 * {@link UpdateHandler} that first selects the ids of all rows matching the update restriction and then
 * issues one SQL {@code UPDATE} per table that receives at least one assignment, each restricted to those ids.
 * <p>
 * For a table that the entity persister reports as nullable, an {@code INSERT ... SELECT} is prepared as well.
 * It is executed only when the {@code UPDATE} touched fewer rows than there were matching ids, so that the
 * missing rows of that table get created.
 */
public class InlineUpdateHandler extends AbstractInlineHandler implements UpdateHandler {
    /** Cross reference from domain query parameters to the JDBC parameters produced for them by translation. */
    private final Map<QueryParameterImplementor<?>, Map<SqmParameter<?>, List<JdbcParametersList>>> jdbcParamsXref;
    /** Mapping-model types resolved for each SQM parameter during translation. */
    private final Map<SqmParameter<?>, MappingModelExpressible<?>> resolvedParameterMappingModelTypes;
    /** One updater per table that has assignments, in the persister's constraint order. */
    private final List<TableUpdater> tableUpdaters;

    /**
     * Translates {@code sqmStatement}, groups its assignments by target table and prepares a {@link TableUpdater}
     * for every table that is assigned to. Finally it merges the bindings already held by
     * {@code firstJdbcParameterBindings} into freshly created bindings for this statement and stores the result
     * back into {@code firstJdbcParameterBindings}.
     *
     * @throws SemanticException if one assignment targets columns of more than one table, or a column of a table
     *                           that is not part of the updated entity's table group
     */
    public InlineUpdateHandler(MatchingIdRestrictionProducer matchingIdsPredicateProducer, SqmUpdateStatement<?> sqmStatement, DomainParameterXref domainParameterXref, DomainQueryExecutionContext context, MutableObject<JdbcParameterBindings> firstJdbcParameterBindings) {
        super(matchingIdsPredicateProducer, sqmStatement, domainParameterXref, context, firstJdbcParameterBindings);
        domainParameterXref.clearExpansions();
        SessionFactoryImplementor sessionFactory = context.getSession().getFactory();
        SqmTranslator<? extends MutationStatement> translator = sessionFactory.getQueryEngine().getSqmTranslatorFactory().createMutationTranslator(sqmStatement, context.getQueryOptions(), domainParameterXref, context.getQueryParameterBindings(), context.getSession().getLoadQueryInfluencers(), sessionFactory.getSqlTranslationEngine());
        SqmTranslation<? extends MutationStatement> translation = translator.translate();
        UpdateStatement updateStatement = (UpdateStatement) translation.getSqlAst();
        TableGroup updatingTableGroup = updateStatement.getFromClause().getRoots().get(0);

        // Index every table reference of the updated table group by its alias so that assignment
        // columns can be mapped back to the table they belong to.
        Map<String, TableReference> tableReferenceByAlias = CollectionHelper.mapOfSize(updatingTableGroup.getTableReferenceJoins().size() + 1);
        this.collectTableReference(updatingTableGroup.getPrimaryTableReference(), tableReferenceByAlias::put);
        for (TableReferenceJoin tableReferenceJoin : updatingTableGroup.getTableReferenceJoins()) {
            this.collectTableReference(tableReferenceJoin, tableReferenceByAlias::put);
        }

        // Group assignments by the single table each one writes to.
        Map<TableReference, List<Assignment>> assignmentsByTable = new HashMap<>();
        for (Assignment assignment : updateStatement.getAssignments()) {
            TableReference assignmentTableReference = null;
            for (ColumnReference columnReference : assignment.getAssignable().getColumnReferences()) {
                TableReference tableReference = this.resolveTableReference(columnReference, tableReferenceByAlias);
                if (assignmentTableReference != null && assignmentTableReference != tableReference) {
                    throw new SemanticException("Assignment referred to columns from multiple tables: " + assignment.getAssignable());
                }
                assignmentTableReference = tableReference;
            }
            assignmentsByTable.computeIfAbsent(assignmentTableReference, key -> new ArrayList<>()).add(assignment);
        }

        List<TableUpdater> updaters = new ArrayList<>();
        SqmJdbcExecutionContextAdapter executionContext = SqmJdbcExecutionContextAdapter.omittingLockingAndPaging(context);
        this.getEntityDescriptor().visitConstraintOrderedTables((tableExpression, tableKeyColumnVisitationSupplier) -> {
            TableUpdater tableUpdater = this.createTableUpdater(tableExpression, tableKeyColumnVisitationSupplier, this.getEntityDescriptor(), updatingTableGroup, assignmentsByTable, executionContext);
            if (tableUpdater != null) {
                updaters.add(tableUpdater);
            }
        });
        this.tableUpdaters = updaters;
        this.jdbcParamsXref = SqmUtil.generateJdbcParamsXref(domainParameterXref, translator);
        this.resolvedParameterMappingModelTypes = translation.getSqmParameterMappingModelTypeResolutions();

        JdbcParameterBindings jdbcParameterBindings = SqmUtil.createJdbcParameterBindings(context.getQueryParameterBindings(), domainParameterXref, this.jdbcParamsXref, this.createParameterResolutionAccess(), context.getSession());
        firstJdbcParameterBindings.get().visitBindings(jdbcParameterBindings::addBinding);
        firstJdbcParameterBindings.set(jdbcParameterBindings);
    }

    /**
     * Creates the JDBC parameter bindings for this statement from the query parameter bindings of
     * {@code context}. The bindings produced by the superclass are then added to the result.
     */
    @Override
    public JdbcParameterBindings createJdbcParameterBindings(DomainQueryExecutionContext context) {
        JdbcParameterBindings jdbcParameterBindings = SqmUtil.createJdbcParameterBindings(context.getQueryParameterBindings(), this.getDomainParameterXref(), this.jdbcParamsXref, this.createParameterResolutionAccess(), context.getSession());
        super.createJdbcParameterBindings(context).visitBindings(jdbcParameterBindings::addBinding);
        return jdbcParameterBindings;
    }

    /**
     * Selects the matching ids and runs every table updater against them.
     *
     * @return the number of matching ids, or {@code 0} when none matched (no table statement is executed then)
     */
    @Override
    public int execute(JdbcParameterBindings jdbcParameterBindings, DomainQueryExecutionContext executionContext) {
        List<Object> ids = MatchingIdSelectionHelper.selectMatchingIds(this.getMatchingIdsInterpretation(), jdbcParameterBindings, executionContext);
        if (ids == null || ids.isEmpty()) {
            return 0;
        }
        List<Expression> inListExpressions = this.getMatchingIdsPredicateProducer().produceIdExpressionList(ids, this.getEntityDescriptor());
        int rows = ids.size();
        SqmJdbcExecutionContextAdapter executionContextAdapter = SqmJdbcExecutionContextAdapter.omittingLockingAndPaging(executionContext);
        for (TableUpdater tableUpdater : this.tableUpdaters) {
            this.updateTable(tableUpdater, inListExpressions, rows, jdbcParameterBindings, executionContextAdapter);
        }
        return rows;
    }

    /** @return the per-table updaters prepared by the constructor */
    protected List<TableUpdater> getTableUpdaters() {
        return this.tableUpdaters;
    }

    /**
     * Builds the resolution access that looks SQM parameters up in {@link #resolvedParameterMappingModelTypes}.
     * The map is read lazily on each call, so this may be created before the field is assigned as long as it is
     * not queried before then.
     */
    private SqmParameterMappingModelResolutionAccess createParameterResolutionAccess() {
        return new SqmParameterMappingModelResolutionAccess() {
            @Override
            @SuppressWarnings("unchecked") // the map stores a type resolved for exactly this parameter
            public <T> MappingModelExpressible<T> getResolvedMappingModelType(SqmParameter<T> parameter) {
                return (MappingModelExpressible<T>) resolvedParameterMappingModelTypes.get(parameter);
            }
        };
    }

    /**
     * Prepares the UPDATE (and, for nullable tables, the INSERT ... SELECT) for a single table.
     *
     * @return the updater, or {@code null} when no assignment targets {@code tableExpression}
     */
    private @Nullable TableUpdater createTableUpdater(String tableExpression, Supplier<Consumer<SelectableConsumer>> tableKeyColumnVisitationSupplier, EntityPersister entityDescriptor, TableGroup updatingTableGroup, Map<TableReference, List<Assignment>> assignmentsByTable, ExecutionContext executionContext) {
        TableReference updatingTableReference = updatingTableGroup.getTableReference(updatingTableGroup.getNavigablePath(), tableExpression, false);
        List<Assignment> assignments = assignmentsByTable.get(updatingTableReference);
        if (assignments == null || assignments.isEmpty()) {
            return null;
        }
        Expression keyExpression = createKeyExpression(updatingTableReference, tableKeyColumnVisitationSupplier, entityDescriptor.getIdentifierMapping());
        NamedTableReference dmlTableReference = this.resolveUnionTableReference(updatingTableReference, tableExpression);
        UpdateStatement sqlAst = new UpdateStatement(dmlTableReference, assignments, null);
        InsertSelectStatement insertSqlAst = isNullableTable(entityDescriptor.getEntityPersister(), tableExpression)
                ? this.createNullableTableInsert(tableExpression, entityDescriptor, updatingTableGroup, keyExpression, assignments, dmlTableReference, executionContext.getSession().getFactory())
                : null;
        return new TableUpdater(sqlAst, insertSqlAst, tableKeyColumnVisitationSupplier);
    }

    /**
     * Builds the expression for the key column(s) of the updated table.
     * <p>
     * A single-column id becomes one {@link ColumnReference} named after the id mapping's selection expression.
     * A multi-column id becomes a {@link SqlTuple} of the columns visited by {@code tableKeyColumnVisitationSupplier}.
     */
    private static Expression createKeyExpression(TableReference updatingTableReference, Supplier<Consumer<SelectableConsumer>> tableKeyColumnVisitationSupplier, EntityIdentifierMapping identifierMapping) {
        int idColumnCount = identifierMapping.getJdbcTypeCount();
        assert idColumnCount > 0;
        if (idColumnCount == 1) {
            BasicValuedModelPart basicIdMapping = NullnessUtil.castNonNull(identifierMapping.asBasicValuedModelPart());
            return new ColumnReference(updatingTableReference, basicIdMapping.getSelectionExpression(), false, null, basicIdMapping.getJdbcMapping());
        }
        List<ColumnReference> columnReferences = new ArrayList<>(idColumnCount);
        tableKeyColumnVisitationSupplier.get().accept((columnIndex, selection) -> columnReferences.add(new ColumnReference(updatingTableReference, selection)));
        return new SqlTuple(columnReferences, identifierMapping);
    }

    /** @return whether {@code entityPersister} lists {@code tableExpression} as one of its nullable tables */
    private static boolean isNullableTable(EntityPersister entityPersister, String tableExpression) {
        for (int i = 0; i < entityPersister.getTableSpan(); ++i) {
            if (tableExpression.equals(entityPersister.getTableName(i)) && entityPersister.isNullableTable(i)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds the INSERT ... SELECT template for a nullable table.
     * <p>
     * The source query selects from a VALUES table group that is initially empty. Its rows are supplied per
     * execution by {@link #createTableInsert}. It LEFT joins the entity's root table group on the id columns and
     * keeps only rows whose target-table key column IS NULL, i.e. ids with no row in that table yet. It selects
     * the key values followed by every assigned value.
     */
    private InsertSelectStatement createNullableTableInsert(String tableExpression, EntityPersister entityDescriptor, TableGroup updatingTableGroup, Expression keyExpression, List<Assignment> assignments, NamedTableReference dmlTableReference, SessionFactoryImplementor sessionFactory) {
        QuerySpec querySpec = new QuerySpec(true);
        NavigablePath valuesPath = new NavigablePath("id");
        EntityIdentifierMapping identifierMapping = entityDescriptor.getIdentifierMapping();
        TableGroup rootTableGroup = entityDescriptor.createRootTableGroup(true, updatingTableGroup.getNavigablePath(), updatingTableGroup.getSourceAlias(), new SqlAliasBaseImpl(updatingTableGroup.getGroupAlias()), () -> predicate -> {}, null);

        List<String> columnNames;
        ComparisonPredicate joinPredicate;
        if (keyExpression instanceof SqlTuple) {
            List<? extends Expression> expressions = ((SqlTuple) keyExpression).getExpressions();
            List<ColumnReference> valuesColumns = new ArrayList<>(expressions.size());
            List<ColumnReference> rootIdColumns = new ArrayList<>(expressions.size());
            List<String> tupleColumnNames = new ArrayList<>(expressions.size());
            identifierMapping.forEachSelectable((i, selectableMapping) -> {
                ColumnReference columnReference = expressions.get(i).getColumnReference();
                ColumnReference valuesColumnReference = new ColumnReference(valuesPath.getLocalName(), columnReference.getColumnExpression(), false, null, columnReference.getJdbcMapping());
                tupleColumnNames.add(columnReference.getColumnExpression());
                valuesColumns.add(valuesColumnReference);
                rootIdColumns.add(new ColumnReference(rootTableGroup.getPrimaryTableReference(), selectableMapping.getSelectionExpression(), false, null, columnReference.getJdbcMapping()));
                querySpec.getSelectClause().addSqlSelection(new SqlSelectionImpl(valuesColumnReference));
            });
            columnNames = tupleColumnNames;
            joinPredicate = new ComparisonPredicate(new SqlTuple(valuesColumns, identifierMapping), ComparisonOperator.EQUAL, new SqlTuple(rootIdColumns, identifierMapping));
        } else {
            ColumnReference columnReference = keyExpression.getColumnReference();
            ColumnReference valuesColumnReference = new ColumnReference(valuesPath.getLocalName(), columnReference.getColumnExpression(), false, null, columnReference.getJdbcMapping());
            columnNames = Collections.singletonList(columnReference.getColumnExpression());
            joinPredicate = new ComparisonPredicate(valuesColumnReference, ComparisonOperator.EQUAL, new ColumnReference(rootTableGroup.getPrimaryTableReference(), ((BasicEntityIdentifierMapping) identifierMapping).getSelectionExpression(), false, null, columnReference.getJdbcMapping()));
            querySpec.getSelectClause().addSqlSelection(new SqlSelectionImpl(valuesColumnReference));
        }

        ValuesTableGroup valuesTableGroup = new ValuesTableGroup(valuesPath, null, new ArrayList<Values>(), valuesPath.getLocalName(), columnNames, true, sessionFactory);
        valuesTableGroup.addNestedTableGroupJoin(new TableGroupJoin(rootTableGroup.getNavigablePath(), SqlAstJoinType.LEFT, rootTableGroup, joinPredicate));
        querySpec.getFromClause().addRoot(valuesTableGroup);
        // NOTE(review): getSingleJdbcMapping() is also reached for multi-column ids here; confirm it is
        // acceptable for composite identifiers (the first key column's JDBC mapping would be the precise choice).
        querySpec.applyPredicate(new NullnessPredicate(new ColumnReference(rootTableGroup.resolveTableReference(tableExpression), columnNames.get(0), identifierMapping.getSingleJdbcMapping())));

        // Target columns: the table's key column(s) first, then every assigned column, matching the select order.
        List<ColumnReference> targetColumnReferences = new ArrayList<>();
        if (keyExpression instanceof SqlTuple) {
            for (Expression keyColumn : ((SqlTuple) keyExpression).getExpressions()) {
                targetColumnReferences.add((ColumnReference) keyColumn);
            }
        } else {
            targetColumnReferences.add((ColumnReference) keyExpression);
        }
        for (Assignment assignment : assignments) {
            targetColumnReferences.addAll(assignment.getAssignable().getColumnReferences());
            querySpec.getSelectClause().addSqlSelection(new SqlSelectionImpl(assignment.getAssignedValue()));
        }
        InsertSelectStatement insertSqlAst = new InsertSelectStatement(dmlTableReference);
        insertSqlAst.addTargetColumnReferences(targetColumnReferences.toArray(new ColumnReference[0]));
        insertSqlAst.setSourceSelectStatement(querySpec);
        return insertSqlAst;
    }

    /**
     * Translates the updater's UPDATE template into a JDBC mutation. The template's restriction is combined with
     * a restriction on the given matching ids.
     */
    protected JdbcOperationQueryMutation createTableUpdate(TableUpdater tableUpdater, List<Expression> inListExpressions, JdbcParameterBindings jdbcParameterBindings, ExecutionContext executionContext) {
        UpdateStatement template = tableUpdater.updateStatement();
        Predicate idRestriction = this.getMatchingIdsPredicateProducer().produceRestriction(inListExpressions, this.getEntityDescriptor(), 0, null, template.getTargetTable(), tableUpdater.tableKeyColumnVisitationSupplier(), executionContext);
        UpdateStatement updateStatement = new UpdateStatement(template, template.getTargetTable(), template.getFromClause(), template.getAssignments(), Predicate.combinePredicates(template.getRestriction(), idRestriction), template.getReturningColumns());
        SessionFactoryImplementor sessionFactory = executionContext.getSession().getFactory();
        return sessionFactory.getJdbcServices().getJdbcEnvironment().getSqlAstTranslatorFactory().buildMutationTranslator(sessionFactory, updateStatement).translate(jdbcParameterBindings, executionContext.getQueryOptions());
    }

    /**
     * Translates the updater's nullable-table INSERT ... SELECT template into a JDBC mutation.
     * <p>
     * The template's source query is copied. Its VALUES root is replaced by one that carries one row per
     * matching id: a tuple id contributes all of its expressions, any other id contributes a single expression.
     */
    protected JdbcOperationQueryMutation createTableInsert(TableUpdater tableUpdater, List<Expression> inListExpressions, JdbcParameterBindings jdbcParameterBindings, ExecutionContext executionContext) {
        SessionFactoryImplementor sessionFactory = executionContext.getSession().getFactory();
        InsertSelectStatement template = tableUpdater.nullableInsert();
        InsertSelectStatement insertStatement = new InsertSelectStatement(template, template.getTargetTable(), template.getReturningColumns());
        QuerySpec originalQuerySpec = (QuerySpec) template.getSourceSelectStatement();
        assert originalQuerySpec.getFromClause().getRoots().size() == 1;

        QuerySpec querySpec = new QuerySpec(true, 1);
        querySpec.getSelectClause().makeDistinct(originalQuerySpec.getSelectClause().isDistinct());
        for (SqlSelection sqlSelection : originalQuerySpec.getSelectClause().getSqlSelections()) {
            querySpec.getSelectClause().addSqlSelection(sqlSelection);
        }
        querySpec.applyPredicate(originalQuerySpec.getWhereClauseRestrictions());
        querySpec.setGroupByClauseExpressions(originalQuerySpec.getGroupByClauseExpressions());
        querySpec.setHavingClauseRestrictions(originalQuerySpec.getHavingClauseRestrictions());
        for (SortSpecification sortSpecification : originalQuerySpec.getSortSpecifications()) {
            querySpec.addSortSpecification(sortSpecification);
        }
        querySpec.setOffsetClauseExpression(originalQuerySpec.getOffsetClauseExpression());
        querySpec.setFetchClauseExpression(originalQuerySpec.getFetchClauseExpression(), originalQuerySpec.getFetchClauseType());

        List<Values> valuesList = new ArrayList<>(inListExpressions.size());
        for (Expression inListExpression : inListExpressions) {
            if (inListExpression instanceof SqlTuple) {
                @SuppressWarnings("unchecked") // read-only view of the tuple's own expressions
                List<Expression> tupleExpressions = (List<Expression>) ((SqlTuple) inListExpression).getExpressions();
                valuesList.add(new Values(tupleExpressions));
            } else {
                valuesList.add(new Values(Collections.singletonList(inListExpression)));
            }
        }
        ValuesTableGroup templateValues = (ValuesTableGroup) originalQuerySpec.getFromClause().getRoots().get(0);
        ValuesTableGroup valuesTableGroup = new ValuesTableGroup(templateValues.getNavigablePath(), (TableGroupProducer) templateValues.getModelPart(), valuesList, templateValues.getNavigablePath().getLocalName(), templateValues.getPrimaryTableReference().getColumnNames(), templateValues.canUseInnerJoins(), sessionFactory);
        valuesTableGroup.addNestedTableGroupJoin(templateValues.getNestedTableGroupJoins().get(0));
        querySpec.getFromClause().addRoot(valuesTableGroup);

        insertStatement.addTargetColumnReferences(template.getTargetColumns());
        insertStatement.setSourceSelectStatement(querySpec);
        return sessionFactory.getJdbcServices().getJdbcEnvironment().getSqlAstTranslatorFactory().buildMutationTranslator(sessionFactory, insertStatement).translate(jdbcParameterBindings, executionContext.getQueryOptions());
    }

    /**
     * Executes the UPDATE for one table. If fewer rows than {@code expectedUpdateCount} were updated and the
     * table has a nullable-insert template, the INSERT ... SELECT is executed as well.
     */
    private void updateTable(TableUpdater tableUpdater, List<Expression> inListExpressions, int expectedUpdateCount, JdbcParameterBindings jdbcParameterBindings, ExecutionContext executionContext) {
        int updateCount = executeMutation(this.createTableUpdate(tableUpdater, inListExpressions, jdbcParameterBindings, executionContext), jdbcParameterBindings, executionContext);
        if (updateCount == expectedUpdateCount || tableUpdater.nullableInsert() == null) {
            return;
        }
        int insertCount = executeMutation(this.createTableInsert(tableUpdater, inListExpressions, jdbcParameterBindings, executionContext), jdbcParameterBindings, executionContext);
        assert insertCount + updateCount == expectedUpdateCount;
    }

    /**
     * Runs {@code mutation} through the session factory's JDBC mutation executor. Statements are prepared via
     * the session's statement preparer and no expectation check is applied.
     *
     * @return the row count reported by the executor
     */
    private static int executeMutation(JdbcOperationQueryMutation mutation, JdbcParameterBindings jdbcParameterBindings, ExecutionContext executionContext) {
        return executionContext.getSession().getFactory().getJdbcServices().getJdbcMutationExecutor().execute(
                mutation,
                jdbcParameterBindings,
                sql -> executionContext.getSession().getJdbcCoordinator().getStatementPreparer().prepareStatement(sql),
                (integer, preparedStatement) -> {},
                executionContext);
    }

    /** Passes {@code tableReference} to {@code consumer} keyed by its identification variable (alias). */
    private void collectTableReference(TableReference tableReference, BiConsumer<String, TableReference> consumer) {
        consumer.accept(tableReference.getIdentificationVariable(), tableReference);
    }

    /** Passes the table reference joined by {@code tableReferenceJoin} to {@code consumer}, keyed by its alias. */
    private void collectTableReference(TableReferenceJoin tableReferenceJoin, BiConsumer<String, TableReference> consumer) {
        this.collectTableReference(tableReferenceJoin.getJoinedTableReference(), consumer);
    }

    /**
     * Looks up the table reference a column belongs to by the column's qualifier.
     *
     * @throws SemanticException if the qualifier is not an alias of the updated entity's own tables
     */
    private TableReference resolveTableReference(ColumnReference columnReference, Map<String, TableReference> tableReferenceByAlias) {
        TableReference tableReferenceByQualifier = tableReferenceByAlias.get(columnReference.getQualifier());
        if (tableReferenceByQualifier != null) {
            return tableReferenceByQualifier;
        }
        throw new SemanticException("Assignment referred to column of a joined association: " + columnReference);
    }

    /**
     * Returns a named reference usable as a DML target.
     * <p>
     * A {@link UnionTableReference} is replaced by a plain reference to {@code tableExpression}, keeping its alias
     * and optionality. Any other reference is cast to {@link NamedTableReference} as-is.
     */
    private NamedTableReference resolveUnionTableReference(TableReference tableReference, String tableExpression) {
        if (tableReference instanceof UnionTableReference) {
            return new NamedTableReference(tableExpression, tableReference.getIdentificationVariable(), tableReference.isOptional());
        }
        return (NamedTableReference) tableReference;
    }

    /**
     * Per-table statements prepared once and specialised for each execution.
     *
     * @param updateStatement                  UPDATE template whose restriction is combined with the id restriction
     * @param nullableInsert                   INSERT ... SELECT template for nullable tables, otherwise {@code null}
     * @param tableKeyColumnVisitationSupplier visits the key columns of this table
     */
    protected record TableUpdater(UpdateStatement updateStatement, @Nullable InsertSelectStatement nullableInsert, Supplier<Consumer<SelectableConsumer>> tableKeyColumnVisitationSupplier) {
    }
}

