/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.mcp.annotation.spring;

import io.modelcontextprotocol.spec.McpSchema;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.springaicommunity.mcp.annotation.McpElicitation;
import org.springaicommunity.mcp.annotation.McpLogging;
import org.springaicommunity.mcp.annotation.McpProgress;
import org.springaicommunity.mcp.annotation.McpPromptListChanged;
import org.springaicommunity.mcp.annotation.McpResourceListChanged;
import org.springaicommunity.mcp.annotation.McpSampling;
import org.springaicommunity.mcp.annotation.McpToolListChanged;
import org.springframework.aop.framework.autoproxy.AutoProxyUtils;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.util.ReflectionUtils;

abstract class AbstractClientMcpHandlerRegistry
implements BeanFactoryPostProcessor {
    protected Map<String, McpSchema.ClientCapabilities> capabilitiesPerClient = new HashMap<String, McpSchema.ClientCapabilities>();
    protected ConfigurableListableBeanFactory beanFactory;
    protected final Set<String> allAnnotatedBeans = new HashSet<String>();
    static final Class<? extends Annotation>[] CLIENT_MCP_ANNOTATIONS = new Class[]{McpSampling.class, McpElicitation.class, McpLogging.class, McpProgress.class, McpToolListChanged.class, McpPromptListChanged.class, McpResourceListChanged.class};
    static final McpSchema.ClientCapabilities EMPTY_CAPABILITIES = new McpSchema.ClientCapabilities(null, null, null, null);

    AbstractClientMcpHandlerRegistry() {
    }

    public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
        this.beanFactory = beanFactory;
        HashMap<String, List> elicitationClientToAnnotatedBeans = new HashMap<String, List>();
        HashMap<String, List> samplingClientToAnnotatedBeans = new HashMap<String, List>();
        for (String beanName : beanFactory.getBeanDefinitionNames()) {
            Class beanClass;
            if (!beanFactory.getBeanDefinition(beanName).isSingleton() || (beanClass = AutoProxyUtils.determineTargetClass((ConfigurableListableBeanFactory)beanFactory, (String)beanName)) == null) continue;
            List<Annotation> foundAnnotations = this.scan(beanClass);
            if (!foundAnnotations.isEmpty()) {
                this.allAnnotatedBeans.add(beanName);
            }
            for (Annotation foundAnnotation : foundAnnotations) {
                if (foundAnnotation instanceof McpSampling) {
                    McpSampling sampling = (McpSampling)foundAnnotation;
                    for (String client : sampling.clients()) {
                        samplingClientToAnnotatedBeans.computeIfAbsent(client, c -> new ArrayList()).add(beanName);
                    }
                    continue;
                }
                if (!(foundAnnotation instanceof McpElicitation)) continue;
                McpElicitation elicitation = (McpElicitation)foundAnnotation;
                for (String client : elicitation.clients()) {
                    elicitationClientToAnnotatedBeans.computeIfAbsent(client, c -> new ArrayList()).add(beanName);
                }
            }
        }
        for (Map.Entry entry2 : elicitationClientToAnnotatedBeans.entrySet()) {
            if (((List)entry2.getValue()).size() <= 1) continue;
            throw new IllegalArgumentException("Found 2 elicitation handlers for client [%s], found in bean with names %s. Only one @McpElicitation handler is allowed per client".formatted(entry2.getKey(), new LinkedHashSet((Collection)entry2.getValue())));
        }
        for (Map.Entry entry3 : samplingClientToAnnotatedBeans.entrySet()) {
            if (((List)entry3.getValue()).size() <= 1) continue;
            throw new IllegalArgumentException("Found 2 sampling handlers for client [%s], found in bean with names %s. Only one @McpSampling handler is allowed per client".formatted(entry3.getKey(), new LinkedHashSet((Collection)entry3.getValue())));
        }
        HashMap<String, McpSchema.ClientCapabilities.Builder> capsPerClient = new HashMap<String, McpSchema.ClientCapabilities.Builder>();
        for (String samplingClient : samplingClientToAnnotatedBeans.keySet()) {
            capsPerClient.computeIfAbsent(samplingClient, ignored -> McpSchema.ClientCapabilities.builder()).sampling();
        }
        for (String elicitationClient : elicitationClientToAnnotatedBeans.keySet()) {
            capsPerClient.computeIfAbsent(elicitationClient, ignored -> McpSchema.ClientCapabilities.builder()).elicitation();
        }
        this.capabilitiesPerClient = capsPerClient.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> ((McpSchema.ClientCapabilities.Builder)entry.getValue()).build()));
    }

    protected List<Annotation> scan(Class<?> beanClass) {
        ArrayList<Annotation> foundAnnotations = new ArrayList<Annotation>();
        ReflectionUtils.doWithMethods(beanClass, method -> {
            for (Class<? extends Annotation> annotationType : CLIENT_MCP_ANNOTATIONS) {
                Annotation annotation = AnnotationUtils.findAnnotation((Method)method, annotationType);
                if (annotation == null) continue;
                foundAnnotations.add(annotation);
            }
        });
        return foundAnnotations;
    }

    protected Map<Class<? extends Annotation>, Set<Object>> getBeansByAnnotationType() {
        HashMap<Class<? extends Annotation>, Set<Object>> beansByAnnotation = new HashMap<Class<? extends Annotation>, Set<Object>>();
        for (Class<? extends Annotation> annotation : CLIENT_MCP_ANNOTATIONS) {
            beansByAnnotation.put(annotation, new HashSet());
        }
        for (String beanName : this.allAnnotatedBeans) {
            Object bean = this.beanFactory.getBean(beanName);
            List<Annotation> annotations = this.scan(bean.getClass());
            for (Annotation annotation : annotations) {
                beansByAnnotation.computeIfAbsent(annotation.annotationType(), k -> new HashSet()).add(bean);
            }
        }
        return beansByAnnotation;
    }
}

