"""
LLM增强模块
支持多种LLM后端：Anthropic Claude、OpenAI、本地Ollama等
用于：
1. 生成丰富的零件语义描述
2. 提取专业关键词
3. 理解用户查询意图
4. 查询扩展和改写
"""

import hashlib
import json
import logging
import os
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Tuple

# Logging setup.
# NOTE(review): basicConfig() is called at import time, so merely importing
# this module may install a root handler at INFO level — confirm that this is
# intended when the module is used as a library rather than run as a script.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class LLMConfig:
    """Settings used to select and call an LLM backend."""
    # Backend name looked up by create_llm_provider(); values listed by the
    # original author: anthropic, openai, ollama, zhipu, qwen
    provider: str = "anthropic"
    # Model identifier sent to the backend. The OpenAI and Ollama providers
    # replace this default value with their own default model.
    model: str = "claude-sonnet-4-20250514"
    # API key; when empty, the Anthropic/OpenAI providers fall back to the
    # ANTHROPIC_API_KEY / OPENAI_API_KEY environment variables.
    api_key: str = ""
    # Base URL of the service (used by the OpenAI, Ollama and HTTP providers).
    base_url: str = ""
    # Sampling temperature forwarded to the backend.
    temperature: float = 0.3
    # Upper bound on generated tokens forwarded to the backend.
    max_tokens: int = 2000
    # Request timeout, passed to requests by the Ollama and HTTP providers.
    timeout: int = 30

    # Proxy setting (if needed).
    # NOTE(review): no code visible in this file reads this field.
    proxy: str = ""


class LLMProvider(ABC):
    """Abstract interface that every LLM backend in this module implements."""

    @abstractmethod
    def generate(self, prompt: str, system_prompt: str = "") -> str:
        """Return the model's text completion for *prompt*.

        *system_prompt* is an optional instruction string; the concrete
        providers in this file return an empty string on failure.
        """

    @abstractmethod
    def is_available(self) -> bool:
        """Return True when the backend can currently be used."""


class AnthropicProvider(LLMProvider):
    """LLM provider backed by the Anthropic Messages API.

    The ``anthropic`` package is imported lazily. When it is not installed or
    no API key can be found, ``client`` stays ``None`` and the provider
    reports itself as unavailable.
    """

    def __init__(self, config: LLMConfig):
        """Store *config* and try to build the SDK client."""
        self.config = config
        self.client = None
        self._init_client()

    def _init_client(self):
        """Create ``self.client`` if the SDK is installed and a key is known."""
        try:
            import anthropic
        except ImportError:
            logger.warning("anthropic库未安装，请运行: pip install anthropic")
            return

        # An explicit key in the config wins over the environment variable.
        api_key = self.config.api_key or os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            return

        client_kwargs = {"api_key": api_key}
        # Honor the configured endpoint and timeout instead of silently
        # ignoring them (the other providers already respect these fields).
        if self.config.base_url:
            client_kwargs["base_url"] = self.config.base_url
        if self.config.timeout:
            client_kwargs["timeout"] = self.config.timeout
        self.client = anthropic.Anthropic(**client_kwargs)

    def is_available(self) -> bool:
        """Return True when an SDK client was successfully created."""
        return self.client is not None

    def generate(self, prompt: str, system_prompt: str = "") -> str:
        """Send *prompt* to the model and return the reply text.

        Args:
            prompt: User message content.
            system_prompt: Optional system instruction; a built-in default is
                used when empty.

        Returns:
            The concatenated text of the reply, or ``""`` when the client is
            unavailable or the API call fails (the error is logged).
        """
        if not self.client:
            return ""

        try:
            message = self.client.messages.create(
                model=self.config.model,
                max_tokens=self.config.max_tokens,
                temperature=self.config.temperature,
                system=system_prompt if system_prompt else "你是一个专业的机械工程师和CAD专家。",
                messages=[{"role": "user", "content": prompt}]
            )
            # The reply is a list of content blocks; it can be empty or hold
            # blocks without text, so collect whatever text is present
            # instead of indexing the first block blindly.
            return "".join(
                getattr(block, "text", "") or ""
                for block in (message.content or [])
            )
        except Exception as e:
            logger.error("Anthropic API调用失败: %s", e)
            return ""


class OpenAIProvider(LLMProvider):
    """LLM provider backed by the OpenAI chat-completions API.

    The ``openai`` package is imported lazily. When it is not installed or no
    API key can be found, ``client`` stays ``None`` and the provider reports
    itself as unavailable.
    """

    def __init__(self, config: LLMConfig):
        """Store *config* and try to build the SDK client."""
        self.config = config
        self.client = None
        self._init_client()

    def _init_client(self):
        """Create ``self.client`` if the SDK is installed and a key is known."""
        try:
            from openai import OpenAI
        except ImportError:
            logger.warning("openai库未安装，请运行: pip install openai")
            return

        # Explicit config values win over the environment variables.
        api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
        base_url = self.config.base_url or os.getenv("OPENAI_BASE_URL")
        if not api_key:
            return

        kwargs = {"api_key": api_key}
        if base_url:
            kwargs["base_url"] = base_url
        # Honor the configured timeout instead of silently ignoring it.
        if self.config.timeout:
            kwargs["timeout"] = self.config.timeout
        self.client = OpenAI(**kwargs)

    def is_available(self) -> bool:
        """Return True when an SDK client was successfully created."""
        return self.client is not None

    def generate(self, prompt: str, system_prompt: str = "") -> str:
        """Send *prompt* to the model and return the reply text.

        Args:
            prompt: User message content.
            system_prompt: Optional system message, sent only when non-empty.

        Returns:
            The reply text, or ``""`` when the client is unavailable, the
            reply has no text content, or the API call fails (logged).
        """
        if not self.client:
            return ""

        try:
            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            messages.append({"role": "user", "content": prompt})

            response = self.client.chat.completions.create(
                # The LLMConfig default is a Claude model name; swap it for an
                # OpenAI model when the caller did not pick one.
                model=self.config.model if self.config.model != "claude-sonnet-4-20250514" else "gpt-4o",
                messages=messages,
                temperature=self.config.temperature,
                max_tokens=self.config.max_tokens
            )
            # ``content`` may be None (e.g. refusals); keep the str contract.
            return response.choices[0].message.content or ""
        except Exception as e:
            logger.error("OpenAI API调用失败: %s", e)
            return ""


class OllamaProvider(LLMProvider):
    """LLM provider that talks to an Ollama server over HTTP.

    ``requests`` is imported lazily inside the methods, so the module can be
    imported without it.
    """

    def __init__(self, config: LLMConfig):
        """Store *config* and resolve the server URL and model name."""
        self.config = config
        # Strip a trailing slash so the joined paths never contain "//".
        self.base_url = (config.base_url or "http://localhost:11434").rstrip("/")
        # The LLMConfig default is a Claude model name; swap it for a local
        # model when the caller did not pick one.
        self.model = config.model if config.model != "claude-sonnet-4-20250514" else "qwen2.5:7b"

    def is_available(self) -> bool:
        """Return True when ``GET /api/tags`` answers with HTTP 200.

        Any failure (missing ``requests``, connection error, timeout) is
        treated as "not available".
        """
        try:
            import requests
            response = requests.get(f"{self.base_url}/api/tags", timeout=2)
        except Exception:
            # Deliberately best-effort, but no longer a bare ``except`` that
            # would also swallow KeyboardInterrupt/SystemExit.
            return False
        return response.status_code == 200

    def generate(self, prompt: str, system_prompt: str = "") -> str:
        """Send *prompt* to ``POST /api/generate`` and return the reply text.

        Args:
            prompt: User prompt.
            system_prompt: Optional instructions, prepended to *prompt*
                separated by a blank line.

        Returns:
            The ``response`` field of the JSON reply, or ``""`` on a non-200
            status or any error (both are logged).
        """
        try:
            import requests

            full_prompt = prompt
            if system_prompt:
                full_prompt = f"{system_prompt}\n\n{prompt}"

            # NOTE(review): config.max_tokens is not forwarded to Ollama
            # here — confirm whether a generation limit is wanted.
            response = requests.post(
                f"{self.base_url}/api/generate",
                json={
                    "model": self.model,
                    "prompt": full_prompt,
                    "stream": False,
                    "options": {
                        "temperature": self.config.temperature
                    }
                },
                timeout=self.config.timeout
            )

            if response.status_code != 200:
                # Surface server-side failures instead of dropping them.
                logger.warning("Ollama API返回非200状态码: %s", response.status_code)
                return ""
            return response.json().get("response", "") or ""
        except Exception as e:
            logger.error("Ollama API调用失败: %s", e)
            return ""


class HTTPAPIProvider(LLMProvider):
    """Generic HTTP provider for services exposing an OpenAI-style chat API.

    ``requests`` is imported lazily inside :meth:`generate`.
    """

    def __init__(self, config: LLMConfig):
        """Store *config* together with its base URL and API key."""
        self.config = config
        self.base_url = config.base_url
        self.api_key = config.api_key

    def is_available(self) -> bool:
        """Return True when both a base URL and an API key are configured."""
        return bool(self.base_url and self.api_key)

    def _endpoint(self) -> str:
        """Build the chat-completions URL from ``self.base_url``.

        Trailing slashes are ignored. A base URL that already ends with a
        version segment (``/v1``, ``/v4`` ...) or with the full
        ``/chat/completions`` path is not given a second ``/v1`` prefix;
        any other base URL gets ``/v1/chat/completions`` appended as before.
        """
        base = (self.base_url or "").rstrip("/")
        if base.endswith("/chat/completions"):
            return base
        if re.search(r"/v\d+$", base):
            return f"{base}/chat/completions"
        return f"{base}/v1/chat/completions"

    def generate(self, prompt: str, system_prompt: str = "") -> str:
        """POST an OpenAI-style chat request and return the reply text.

        Args:
            prompt: User message content.
            system_prompt: Optional system message, sent only when non-empty.

        Returns:
            The first choice's message content, or ``""`` on a non-200
            status, missing content, or any error (failures are logged).
        """
        try:
            import requests

            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            messages.append({"role": "user", "content": prompt})

            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}"
            }

            response = requests.post(
                self._endpoint(),
                headers=headers,
                json={
                    "model": self.config.model,
                    "messages": messages,
                    "temperature": self.config.temperature,
                    "max_tokens": self.config.max_tokens
                },
                timeout=self.config.timeout
            )

            if response.status_code != 200:
                # Surface server-side failures instead of dropping them.
                logger.warning("HTTP API返回非200状态码: %s", response.status_code)
                return ""
            # ``content`` may be null in the payload; keep the str contract.
            return response.json()["choices"][0]["message"]["content"] or ""
        except Exception as e:
            logger.error("HTTP API调用失败: %s", e)
            return ""


def create_llm_provider(config: LLMConfig) -> LLMProvider:
    """Factory: build the provider instance selected by ``config.provider``.

    The provider name is matched case-insensitively and ignoring surrounding
    whitespace, so ``"OpenAI"`` selects the OpenAI provider rather than
    silently falling through. Names without a dedicated class fall back to
    :class:`HTTPAPIProvider`.

    Args:
        config: Settings passed unchanged to the provider constructor.

    Returns:
        A new provider instance.
    """
    providers = {
        "anthropic": AnthropicProvider,
        "openai": OpenAIProvider,
        "ollama": OllamaProvider,
        "http": HTTPAPIProvider,
    }

    provider_key = (config.provider or "").strip().lower()
    provider_class = providers.get(provider_key, HTTPAPIProvider)
    return provider_class(config)


# ============== Prompt templates ==============
# Each template is filled in with str.format() by LLMEnhancedAnalyzer, which
# is why the literal braces of the embedded JSON examples are doubled.
# Placeholders: part_analysis -> filename, features, topology, patterns;
# query_understanding / query_expansion -> query;
# feature_extraction -> feature_data (not referenced by the code visible in
# this file).

PROMPTS = {
    "part_analysis": """分析以下CAD零件的特征数据，生成专业的语义描述。

## 零件信息
文件名: {filename}

## 特征列表
{features}

## 拓扑关系
{topology}

## 特征模式
{patterns}

请按以下格式输出（使用JSON）：
```json
{{
    "part_type": "零件类型（如：三通、螺钉、法兰等）",
    "part_category": "零件大类（如：管件、紧固件、连接件等）",
    "description": "详细的自然语言描述（100-200字）",
    "function": "零件的主要功能和用途",
    "keywords": ["关键词1", "关键词2", ...],
    "industry_terms": ["行业术语1", "行业术语2", ...],
    "similar_parts": ["相似零件1", "相似零件2", ...],
    "material_suggestion": "推荐材料",
    "manufacturing_process": "推荐加工工艺",
    "standards": ["相关标准（如GB、ISO等）"]
}}
```

注意：
1. 关键词要包含中英文、同义词、缩写
2. 描述要专业但易懂
3. 尽可能识别零件的规格型号（如M6、DN50等）
""",

    "query_understanding": """理解用户的零件搜索查询，提取搜索意图。

用户查询: "{query}"

请分析并输出JSON：
```json
{{
    "original_query": "原始查询",
    "intent": "搜索意图（如：找零件、问规格、比较等）",
    "part_types": ["可能的零件类型"],
    "features": ["提到的特征"],
    "specifications": ["规格参数（如M6、DN50等）"],
    "expanded_queries": ["扩展的搜索词"],
    "synonyms": ["同义词"],
    "filters": {{
        "size_range": null,
        "material": null,
        "standard": null
    }}
}}
```
""",

    "query_expansion": """为以下搜索查询生成扩展词和同义词。

查询: "{query}"

请列出：
1. 同义词（中英文）
2. 相关术语
3. 上位词（更通用的词）
4. 下位词（更具体的词）
5. 相关规格型号

输出JSON格式：
```json
{{
    "synonyms": ["同义词列表"],
    "related_terms": ["相关术语"],
    "broader_terms": ["上位词"],
    "narrower_terms": ["下位词"],
    "specifications": ["规格型号"]
}}
```
""",

    "feature_extraction": """从以下CAD特征描述中提取关键信息。

特征数据:
{feature_data}

请提取：
1. 几何特征类型和参数
2. 可能的功能用途
3. 相关的行业术语

输出JSON。
""",
}


class LLMEnhancedAnalyzer:
    """LLM-backed semantic analyzer for CAD part data and search queries.

    Every public method degrades gracefully: when the provider is unavailable
    or the LLM reply cannot be parsed, a neutral fallback is returned so that
    a rule-based pipeline can take over.
    """

    def __init__(self, config: "Optional[LLMConfig]" = None):
        """Create the provider described by *config* (defaults if ``None``)."""
        self.config = config or LLMConfig()
        self.provider = create_llm_provider(self.config)
        self.cache = {}  # simple in-memory cache: key -> parsed LLM result

        if not self.provider.is_available():
            logger.warning("LLM提供者 %s 不可用，将使用规则引擎", self.config.provider)

    @staticmethod
    def _as_str_list(value: Any) -> List[str]:
        """Coerce an LLM-supplied field into a list of non-empty strings.

        LLM JSON is not guaranteed to follow the requested schema: a field
        may be ``null``, a bare string, or a list with non-string items.
        """
        if isinstance(value, str):
            value = [value]
        if not isinstance(value, (list, tuple, set)):
            return []
        return [item.strip() for item in value if isinstance(item, str) and item.strip()]

    @staticmethod
    def _part_cache_key(filename: str, features_data: Dict) -> str:
        """Build a cache key that identifies the part by its content.

        The filename alone is not unique (it defaults to ``"unknown"``), so a
        digest of the whole feature payload is appended.
        """
        try:
            payload = json.dumps(features_data, sort_keys=True, default=str)
        except (TypeError, ValueError):
            # e.g. keys of mixed types cannot be sorted; fall back to repr.
            payload = repr(features_data)
        digest = hashlib.sha256(payload.encode("utf-8", "backslashreplace")).hexdigest()
        return f"part_{filename}_{digest}"

    def _parse_json_response(self, response: str) -> Dict:
        """Extract a JSON object from an LLM reply.

        Tries, in order: the first ```json fenced block, the whole reply, and
        the span from the first ``{`` to the last ``}``. Each candidate is
        parsed as-is and then once more with trailing commas removed.

        Returns:
            The parsed object, or ``{}`` when nothing parses to a dict.
        """
        if not response:
            return {}

        candidates = []
        fenced = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
        if fenced:
            candidates.append(fenced.group(1))
        candidates.append(response.strip())
        start, end = response.find("{"), response.rfind("}")
        if 0 <= start < end:
            # Handles replies that wrap the JSON in explanatory prose.
            candidates.append(response[start:end + 1])

        for candidate in candidates:
            # Second attempt repairs the common "trailing comma" mistake.
            for text in (candidate, re.sub(r',\s*([}\]])', r'\1', candidate)):
                try:
                    parsed = json.loads(text)
                except ValueError:
                    continue
                # Callers use dict methods on the result; reject lists etc.
                if isinstance(parsed, dict):
                    return parsed

        logger.warning("JSON解析失败")
        return {}

    def analyze_part(self, features_data: Dict) -> Dict:
        """Ask the LLM to classify and describe a part.

        Args:
            features_data: Extracted CAD data; the keys read here are
                ``file_info.filename``, ``features``, ``topology_relations``
                and ``feature_patterns``.

        Returns:
            The parsed JSON object from the LLM, or ``{}`` when the provider
            is unavailable or the reply could not be parsed. Successful
            results are cached per part content.
        """
        if not self.provider.is_available():
            return {}

        filename = features_data.get("file_info", {}).get("filename", "unknown")

        cache_key = self._part_cache_key(filename, features_data)
        if cache_key in self.cache:
            return self.cache[cache_key]

        # Render the inputs as bullet lists, capped to keep the prompt short.
        features = features_data.get("features", [])
        features_text = "\n".join(f"- {f}" for f in features[:20])

        topology = features_data.get("topology_relations", [])
        topology_text = "\n".join(
            f"- {r.get('feature1_type', '')} 与 {r.get('feature2_type', '')} {r.get('relation_type', '')}"
            for r in topology[:15]
        )

        patterns = features_data.get("feature_patterns", [])
        patterns_text = "\n".join(
            f"- {p.get('type', '')}: {p.get('plane', '')}"
            for p in patterns[:10]
        )

        prompt = PROMPTS["part_analysis"].format(
            filename=filename,
            features=features_text or "无",
            topology=topology_text or "无",
            patterns=patterns_text or "无"
        )

        system_prompt = """你是一个资深的机械工程师和CAD专家，精通各类机械零件的识别和分析。
你需要根据CAD特征数据分析零件类型、功能和用途。
请使用专业但易懂的语言，并确保输出有效的JSON格式。"""

        response = self.provider.generate(prompt, system_prompt)
        result = self._parse_json_response(response)

        # Only cache successes so a failed call can be retried later.
        if result:
            self.cache[cache_key] = result

        return result

    def understand_query(self, query: str) -> Dict:
        """Analyze the intent of a user search query.

        Returns:
            The parsed JSON object from the LLM. When the provider is
            unavailable or the reply is unusable, a minimal dict with
            ``original_query`` and ``expanded_queries`` is returned instead,
            so callers always receive the same shape.
        """
        fallback = {"original_query": query, "expanded_queries": [query]}
        if not self.provider.is_available():
            return fallback

        cache_key = f"query_{query}"
        if cache_key in self.cache:
            return self.cache[cache_key]

        prompt = PROMPTS["query_understanding"].format(query=query)

        system_prompt = """你是一个机械零件搜索助手，帮助用户找到需要的零件。
请分析用户的搜索意图，提取关键信息，并生成搜索扩展词。"""

        response = self.provider.generate(prompt, system_prompt)
        result = self._parse_json_response(response)

        if not result:
            # Not cached, so a later call can retry the LLM.
            return fallback

        self.cache[cache_key] = result
        return result

    def expand_query(self, query: str) -> List[str]:
        """Return *query* followed by LLM-suggested expansion terms.

        Synonyms, related terms and narrower terms are used. The result is
        de-duplicated while preserving order, so the original query is always
        the first element.
        """
        if not self.provider.is_available():
            return [query]

        prompt = PROMPTS["query_expansion"].format(query=query)

        response = self.provider.generate(prompt)
        result = self._parse_json_response(response)

        expanded = [query]
        for field in ("synonyms", "related_terms", "narrower_terms"):
            expanded.extend(self._as_str_list(result.get(field)))

        return list(dict.fromkeys(expanded))

    def generate_description(self, features_data: Dict,
                            rule_based_result: Dict = None) -> str:
        """Build an enriched part description.

        Uses the LLM description (plus function and up to five industry
        terms) when available; otherwise falls back to the ``description``
        of *rule_based_result*, or ``""``.
        """
        llm_result = self.analyze_part(features_data)

        description = llm_result.get("description") if llm_result else None
        if description and isinstance(description, str):
            if llm_result.get("function"):
                description += f" 主要功能：{llm_result['function']}"

            terms = "、".join(self._as_str_list(llm_result.get("industry_terms"))[:5])
            if terms:
                description += f" 相关术语：{terms}"

            return description

        # Fall back to the rule-engine result.
        if rule_based_result:
            return rule_based_result.get("description", "")

        return ""

    def extract_keywords(self, features_data: Dict,
                        rule_based_keywords: List[str] = None) -> List[str]:
        """Merge rule-based keywords with keywords suggested by the LLM.

        Returns:
            A de-duplicated list in stable order: the rule-based keywords
            first, then LLM keywords, industry terms, similar parts, part
            type and part category.
        """
        llm_result = self.analyze_part(features_data)

        keywords = list(rule_based_keywords or [])

        if llm_result:
            for field in ("keywords", "industry_terms", "similar_parts",
                          "part_type", "part_category"):
                keywords.extend(self._as_str_list(llm_result.get(field)))

        return list(dict.fromkeys(keywords))


# ============== Manual smoke test ==============

if __name__ == "__main__":
    # Banner
    banner = "=" * 60
    print(banner)
    print("LLM增强模块测试")
    print(banner)

    # Demo configuration: Anthropic backend, key read from the environment.
    config = LLMConfig(
        provider="anthropic",
        model="claude-sonnet-4-20250514",
        api_key=os.getenv("ANTHROPIC_API_KEY", ""),
    )

    analyzer = LLMEnhancedAnalyzer(config)

    print(f"\nLLM提供者: {config.provider}")
    print(f"是否可用: {analyzer.provider.is_available()}")

    if not analyzer.provider.is_available():
        # Without a usable provider, only explain how to enable one.
        for hint in (
            "\n提示: 设置环境变量以启用LLM功能:",
            "  export ANTHROPIC_API_KEY=your_key",
            "  # 或",
            "  export OPENAI_API_KEY=your_key",
        ):
            print(hint)
    else:
        # Exercise query understanding on a few sample part queries.
        test_queries = ["三通", "M6内六角螺钉", "DN50法兰"]

        print("\n查询理解测试:")
        print("-" * 40)

        for sample in test_queries:
            result = analyzer.understand_query(sample)
            print(f"\n查询: {sample}")
            if not result:
                continue
            print(f"  意图: {result.get('intent', 'N/A')}")
            print(f"  零件类型: {result.get('part_types', [])}")
            print(f"  扩展词: {result.get('expanded_queries', [])[:5]}")
