"""
LLM增强版三维模型语义搜索 API 服务
支持智能搜索、对话式交互、查询意图理解
"""

from fastapi import FastAPI, HTTPException, Query, UploadFile, File, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
import json
import os
from pathlib import Path
import uvicorn

# 导入模块
from feature_semantic_analyzer import FeatureSemanticAnalyzer
from llm_enhanced import LLMConfig, LLMEnhancedAnalyzer
from llm_search_engine import LLMEnhancedSearchEngine, create_enhanced_search_engine

# FastAPI application object. The title, description and version are passed to
# FastAPI as metadata for the API.
app = FastAPI(
    title="三维模型智能搜索API (LLM增强版)",
    description="基于LLM的CAD模型智能检索系统，支持自然语言查询和对话式交互",
    version="2.0.0"
)

# CORS configuration: every origin, method and header is allowed, and
# credentials are enabled.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive — confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Module-level singletons shared by all request handlers.
# The rule-based analyzer is constructed at import time.
analyzer = FeatureSemanticAnalyzer()
# These two start as None; init_engine() assigns both, and the /config/llm
# handler replaces llm_analyzer. Handlers must tolerate the None state.
llm_analyzer: Optional[LLMEnhancedAnalyzer] = None
search_engine: Optional[LLMEnhancedSearchEngine] = None

# Storage locations, relative to the process working directory.
# DATA_DIR holds the feature JSON files (uploads and *_features.json inputs);
# INDEX_FILE is where the search index is loaded from and saved to.
DATA_DIR = Path("./data")
INDEX_FILE = DATA_DIR / "enhanced_search_index.json"


# ============== 配置模型 ==============

class LLMConfigModel(BaseModel):
    # Request body of POST /config/llm: the LLM settings to switch to.
    # Every field has a default, so an empty JSON object is accepted.
    provider: str = Field(default="anthropic", description="LLM提供者")
    model: str = Field(default="claude-sonnet-4-20250514", description="模型名称")
    api_key: str = Field(default="", description="API密钥")
    base_url: str = Field(default="", description="API基础URL")
    # Validated to the closed range [0, 1].
    temperature: float = Field(default=0.3, ge=0, le=1)


class SearchRequest(BaseModel):
    # Request body of POST /search (the GET variant builds the same model from
    # query parameters).
    # query is the only required field.
    query: str = Field(..., description="搜索查询词", example="三通管件")
    # Maximum number of results, validated to 1..100.
    top_k: int = Field(default=10, ge=1, le=100)
    # Forwarded to the search engine's smart_search() as its explain flag.
    explain: bool = Field(default=True, description="是否生成LLM解释")


class ConversationMessage(BaseModel):
    # One turn of chat history; both fields are required.
    # Per the field description, role is expected to be "user" or "assistant"
    # (not enforced by validation here).
    role: str = Field(..., description="角色: user/assistant")
    content: str = Field(..., description="消息内容")


class ConversationRequest(BaseModel):
    # Request body of POST /chat: the new user message plus optional earlier turns.
    message: str = Field(..., description="用户消息")
    # Defaults to an empty history when omitted.
    history: List[ConversationMessage] = Field(default=[], description="对话历史")


class SearchResultItem(BaseModel):
    # One search hit as serialized inside SearchResponse and ConversationResponse.
    # The first seven fields are required; the last two are optional extras.
    file_id: str
    filename: str
    part_type: str
    score: float
    match_type: str
    highlights: List[str]
    description: str
    llm_explanation: str = ""
    related_parts: List[str] = []


class SearchResponse(BaseModel):
    # Response body of the /search endpoints.
    query: str
    # Output of the LLM query-understanding step; stays {} when that step is skipped.
    query_understanding: Dict[str, Any] = {}
    # The handler sets this to len(results).
    total_results: int
    results: List[SearchResultItem]


class ConversationResponse(BaseModel):
    # Response body of POST /chat: the reply text plus optional hits and
    # follow-up suggestions (both default to empty lists).
    message: str
    results: List[SearchResultItem] = []
    suggestions: List[str] = []


# ============== 初始化 ==============

def get_llm_config() -> LLMConfig:
    """Build the LLM configuration from environment variables.

    Variables read (with defaults):
        LLM_PROVIDER ("anthropic"), LLM_MODEL ("claude-sonnet-4-20250514"),
        LLM_BASE_URL (""), LLM_TEMPERATURE ("0.3"),
        ANTHROPIC_API_KEY / OPENAI_API_KEY ("").

    The API key is chosen to match the provider: when LLM_PROVIDER is
    "openai" the OpenAI key is preferred, otherwise the Anthropic key is.
    In both cases the other key is used as a fallback, and an empty value
    counts as "not set".

    Returns:
        An ``LLMConfig`` populated from the environment.

    Raises:
        ValueError: if LLM_TEMPERATURE is set to a non-numeric value.
    """
    provider = os.getenv("LLM_PROVIDER", "anthropic")
    anthropic_key = os.getenv("ANTHROPIC_API_KEY", "")
    openai_key = os.getenv("OPENAI_API_KEY", "")

    # Prefer the key that belongs to the selected provider so that having both
    # keys exported does not send the Anthropic key to an OpenAI backend.
    if provider.strip().lower() == "openai":
        api_key = openai_key or anthropic_key
    else:
        api_key = anthropic_key or openai_key

    return LLMConfig(
        provider=provider,
        model=os.getenv("LLM_MODEL", "claude-sonnet-4-20250514"),
        api_key=api_key,
        base_url=os.getenv("LLM_BASE_URL", ""),
        temperature=float(os.getenv("LLM_TEMPERATURE", "0.3")),
    )


def init_engine():
    """Create the global LLM analyzer and search engine, then load and extend the index.

    Steps, in order:
      1. Build the LLM config via ``get_llm_config()`` and assign the two
         module globals ``llm_analyzer`` and ``search_engine``.
      2. Ensure ``DATA_DIR`` exists and, if ``INDEX_FILE`` is present, load it
         (a load failure is printed, not raised).
      3. For every ``*_features.json`` file in ``DATA_DIR`` whose stem is not
         already contained in ``search_engine.models``, run the rule-based
         analyzer and index the model; per-file failures are printed and skipped.
      4. If at least one feature file was found, call ``build_vector_index()``.

    Returns:
        The global ``search_engine`` instance.
    """
    global search_engine, llm_analyzer

    config = get_llm_config()
    llm_analyzer = LLMEnhancedAnalyzer(config)
    search_engine = LLMEnhancedSearchEngine(config)

    DATA_DIR.mkdir(parents=True, exist_ok=True)

    # Restore a previously saved index; failure here is non-fatal.
    if INDEX_FILE.exists():
        try:
            search_engine.load_index(str(INDEX_FILE))
            print(f"✓ 已加载索引，包含 {len(search_engine.models)} 个模型")
        except Exception as exc:
            print(f"✗ 加载索引失败: {exc}")

    feature_files = list(DATA_DIR.glob("*_features.json"))
    if not feature_files:
        return search_engine

    for feature_path in feature_files:
        # Skip files whose stem is already present in the loaded models.
        if str(feature_path.stem) in search_engine.models:
            continue
        try:
            with open(feature_path, 'r', encoding='utf-8') as fh:
                features_data = json.load(fh)

            rule_based = analyzer.to_dict(analyzer.analyze_json(str(feature_path)))
            search_engine.index_model_with_llm(features_data, rule_based)
        except Exception as exc:
            print(f"✗ 索引失败: {feature_path} - {exc}")

    search_engine.build_vector_index()
    return search_engine


# ============== API 端点 ==============

# Health-check endpoint: reports service name, version, whether the LLM
# provider is available, and how many models the engine currently holds
# (0 when the engine is not initialized).
# The docstring is kept verbatim because FastAPI exposes it as the operation
# description in the generated API docs.
@app.get("/", tags=["健康检查"])
async def root():
    """API健康检查"""
    if llm_analyzer and llm_analyzer.provider.is_available():
        llm_status = "可用"
    else:
        llm_status = "不可用"

    models_count = len(search_engine.models) if search_engine else 0

    return {
        "service": "三维模型智能搜索API",
        "version": "2.0.0",
        "llm_status": llm_status,
        "models_count": models_count,
    }


# POST /search: LLM-enhanced search.
# Responds 500 when the engine is not initialized. When an LLM provider is
# available the query is first passed to understand_query(); otherwise the
# query_understanding field stays {}. The engine's hits are then converted
# one-to-one into SearchResultItem objects.
# The docstring is kept verbatim because FastAPI exposes it as the operation
# description in the generated API docs.
@app.post("/search", response_model=SearchResponse, tags=["智能搜索"])
async def smart_search(request: SearchRequest):
    """
    智能搜索（LLM增强）

    特性：
    - 查询意图理解
    - 自动查询扩展
    - 同义词匹配
    - 结果解释生成
    """
    if not search_engine:
        raise HTTPException(status_code=500, detail="搜索引擎未初始化")

    # Optional query-understanding step.
    llm_ready = llm_analyzer and llm_analyzer.provider.is_available()
    query_understanding = llm_analyzer.understand_query(request.query) if llm_ready else {}

    hits = search_engine.smart_search(
        query=request.query,
        top_k=request.top_k,
        explain=request.explain
    )
    total = len(hits)

    items = []
    for hit in hits:
        items.append(
            SearchResultItem(
                file_id=hit.file_id,
                filename=hit.filename,
                part_type=hit.part_type,
                score=hit.score,
                match_type=hit.match_type,
                highlights=hit.highlights,
                description=hit.description,
                llm_explanation=hit.llm_explanation,
                related_parts=hit.related_parts,
            )
        )

    return SearchResponse(
        query=request.query,
        query_understanding=query_understanding,
        total_results=total,
        results=items,
    )


# GET /search: same search as the POST variant. The query parameters are
# packed into a SearchRequest and delegated to smart_search().
# The docstring is kept verbatim because FastAPI exposes it as the operation
# description in the generated API docs.
@app.get("/search", response_model=SearchResponse, tags=["智能搜索"])
async def smart_search_get(
    query: str = Query(..., description="搜索查询词"),
    top_k: int = Query(default=10, ge=1, le=100),
    explain: bool = Query(default=True)
):
    """GET方式智能搜索"""
    search_request = SearchRequest(query=query, top_k=top_k, explain=explain)
    return await smart_search(search_request)


# POST /chat: conversational search.
# Responds 500 when the engine is not initialized. The typed history is
# flattened to plain {"role", "content"} dicts before being handed to the
# engine. The engine's result dict must contain "message"; "results" and
# "suggestions" are optional and default to empty lists.
# The docstring is kept verbatim because FastAPI exposes it as the operation
# description in the generated API docs.
@app.post("/chat", response_model=ConversationResponse, tags=["对话式搜索"])
async def conversational_search(request: ConversationRequest):
    """
    对话式搜索

    支持自然语言交互，理解上下文
    """
    if not search_engine:
        raise HTTPException(status_code=500, detail="搜索引擎未初始化")

    history = []
    for turn in request.history:
        history.append({"role": turn.role, "content": turn.content})

    outcome = search_engine.conversational_search(request.message, history)

    reply = outcome["message"]
    # Hits may lack the LLM-specific attributes, hence the getattr fallbacks.
    items = [
        SearchResultItem(
            file_id=hit.file_id,
            filename=hit.filename,
            part_type=hit.part_type,
            score=hit.score,
            match_type=hit.match_type,
            highlights=hit.highlights,
            description=hit.description,
            llm_explanation=getattr(hit, 'llm_explanation', ''),
            related_parts=getattr(hit, 'related_parts', []),
        )
        for hit in outcome.get("results", [])
    ]
    suggestions = outcome.get("suggestions", [])

    return ConversationResponse(message=reply, results=items, suggestions=suggestions)


# POST /analyze: upload a feature JSON file, analyze it and add it to the index.
#
# Flow: validate the upload name -> parse the JSON -> save the raw bytes under
# DATA_DIR -> rule-based analysis -> LLM-enhanced indexing -> rebuild the
# vector index -> persist the index to INDEX_FILE.
#
# Responses:
#   400  the (sanitized) file name does not end in ".json", or the payload is
#        not valid UTF-8 / JSON
#   500  the search engine is not initialized, or any later step fails
#
# background_tasks is accepted for interface compatibility and is not used.
# The docstring is kept verbatim because FastAPI exposes it as the operation
# description in the generated API docs.
@app.post("/analyze", tags=["分析"])
async def analyze_and_index(
    file: UploadFile = File(...),
    background_tasks: BackgroundTasks = None
):
    """
    上传并分析JSON特征文件（LLM增强）

    自动：
    1. 识别零件类型
    2. 生成语义描述
    3. 提取行业术语
    4. 添加到搜索索引
    """
    # The file name is client-controlled. Keep only its final path component
    # (treating both "/" and "\" as separators) so that a name such as
    # "../../etc/x.json" cannot escape DATA_DIR. A missing name becomes "".
    raw_name = (file.filename or "").replace("\\", "/")
    safe_name = Path(raw_name).name

    if not safe_name.endswith('.json'):
        raise HTTPException(status_code=400, detail="只支持JSON文件")

    # Fail before touching the disk if the engine is not ready.
    if search_engine is None:
        raise HTTPException(status_code=500, detail="搜索引擎未初始化")

    try:
        content = await file.read()
        features_data = json.loads(content.decode('utf-8'))

        # Save the upload; the rule-based analyzer reads it back from disk.
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        save_path = DATA_DIR / safe_name
        with open(save_path, 'wb') as f:
            f.write(content)

        # Rule-based analysis
        semantics = analyzer.analyze_json(str(save_path))
        rule_based = analyzer.to_dict(semantics)

        # LLM-enhanced indexing
        enhanced = search_engine.index_model_with_llm(features_data, rule_based)
        search_engine.build_vector_index()

        # Persist the updated index
        search_engine.save_index(str(INDEX_FILE))

        return {
            "success": True,
            "file_id": enhanced.get("file_id", ""),
            "filename": enhanced.get("filename", ""),
            "part_type": enhanced.get("part_type", ""),
            "part_category": enhanced.get("part_category", ""),
            "description": enhanced.get("description", ""),
            "function": enhanced.get("function", ""),
            "keywords": enhanced.get("keywords", [])[:20],
            "similar_parts": enhanced.get("similar_parts", []),
            "standards": enhanced.get("standards", []),
            "llm_enhanced": llm_analyzer.provider.is_available() if llm_analyzer else False
        }

    except (json.JSONDecodeError, UnicodeDecodeError):
        # Malformed payload (bad JSON or not UTF-8) is a client error.
        raise HTTPException(status_code=400, detail="无效的JSON格式")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"分析失败: {str(e)}")


# GET /models/{file_id}/details: return whatever the engine's
# get_part_details() yields for the id.
# Responds 500 when the engine is not initialized and 404 when the lookup
# result is empty/falsy.
# The docstring is kept verbatim because FastAPI exposes it as the operation
# description in the generated API docs.
@app.get("/models/{file_id}/details", tags=["模型详情"])
async def get_model_details(file_id: str):
    """
    获取零件详细信息（LLM增强）

    包括：技术描述、应用场景、选型建议等
    """
    if not search_engine:
        raise HTTPException(status_code=500, detail="搜索引擎未初始化")

    details = search_engine.get_part_details(file_id)
    if details:
        return details

    raise HTTPException(status_code=404, detail="模型不存在")


# POST /query/expand: expand a query through the LLM analyzer.
# When no LLM provider is available the response degrades to the original
# query as the only expansion.
# The docstring is kept verbatim because FastAPI exposes it as the operation
# description in the generated API docs.
@app.post("/query/expand", tags=["查询工具"])
async def expand_query(query: str = Query(..., description="查询词")):
    """
    查询扩展

    使用LLM生成同义词和相关术语
    """
    llm_ready = llm_analyzer and llm_analyzer.provider.is_available()
    if llm_ready:
        return {
            "original": query,
            "expanded": llm_analyzer.expand_query(query),
        }

    # Fallback without an LLM.
    return {"original": query, "expanded": [query]}


# POST /query/understand: run the LLM analyzer's understand_query() on the query.
# When no LLM provider is available the "understanding" field is an empty dict.
# The docstring is kept verbatim because FastAPI exposes it as the operation
# description in the generated API docs.
@app.post("/query/understand", tags=["查询工具"])
async def understand_query(query: str = Query(..., description="查询词")):
    """
    查询意图理解

    分析用户搜索意图
    """
    llm_ready = llm_analyzer and llm_analyzer.provider.is_available()
    if llm_ready:
        return {
            "query": query,
            "understanding": llm_analyzer.understand_query(query),
        }

    # Fallback without an LLM.
    return {"query": query, "understanding": {}}


# GET /stats: index statistics.
# Returns {"error": ...} (with a normal 200 status) when the engine is not
# initialized. Otherwise reports the model count, the number of ids per part
# type, the size of the keyword index, and the LLM availability and provider.
# The docstring is kept verbatim because FastAPI exposes it as the operation
# description in the generated API docs.
@app.get("/stats", tags=["统计"])
async def get_stats():
    """获取系统统计信息"""
    if not search_engine:
        return {"error": "引擎未初始化"}

    # part type -> number of indexed ids of that type
    part_type_counts = {
        part_type: len(ids)
        for part_type, ids in search_engine.part_type_index.items()
    }

    return {
        "total_models": len(search_engine.models),
        "part_types": part_type_counts,
        "total_keywords": len(search_engine.keyword_index),
        "llm_available": llm_analyzer.provider.is_available() if llm_analyzer else False,
        "llm_provider": search_engine.llm_config.provider,
    }


# POST /config/llm: swap the LLM configuration at runtime.
# Builds a new LLMConfig and LLMEnhancedAnalyzer from the request body, then
# points the existing search engine at them and refreshes its llm_available
# flag. Responds 500 when the search engine is not initialized; the check runs
# first so that a failed request leaves the global analyzer untouched.
# The docstring is kept verbatim because FastAPI exposes it as the operation
# description in the generated API docs.
@app.post("/config/llm", tags=["配置"])
async def update_llm_config(config: LLMConfigModel):
    """更新LLM配置"""
    global llm_analyzer, search_engine

    if search_engine is None:
        raise HTTPException(status_code=500, detail="搜索引擎未初始化")

    new_config = LLMConfig(
        provider=config.provider,
        model=config.model,
        api_key=config.api_key,
        base_url=config.base_url,
        temperature=config.temperature
    )

    llm_analyzer = LLMEnhancedAnalyzer(new_config)
    search_engine.llm_config = new_config
    search_engine.llm_analyzer = llm_analyzer
    search_engine.llm_available = llm_analyzer.provider.is_available()

    return {
        "success": True,
        "llm_available": search_engine.llm_available,
        "provider": config.provider
    }


# ============== Dify 集成 ==============

# POST /dify/smart_search: search endpoint for Dify workflows.
# Runs smart_search() with explanations enabled and returns both a Markdown
# rendering ("formatted_output") and the raw fields of each hit
# ("raw_results"). Responds 500 when the search engine is not initialized,
# matching the /search endpoint.
# The docstring is kept verbatim because FastAPI exposes it as the operation
# description in the generated API docs.
@app.post("/dify/smart_search", tags=["Dify集成"])
async def dify_smart_search(
    query: str = Query(..., description="搜索关键词"),
    limit: int = Query(default=5, ge=1, le=20)
):
    """
    Dify工作流智能搜索接口

    返回格式化的搜索结果，适合LLM处理
    """
    if search_engine is None:
        raise HTTPException(status_code=500, detail="搜索引擎未初始化")

    results = search_engine.smart_search(query=query, top_k=limit, explain=True)

    # Build the Markdown output: a heading, then one section per hit.
    output_parts = [f"## 搜索结果: {query}\n"]

    if not results:
        output_parts.append("未找到匹配的零件。")
    else:
        for i, r in enumerate(results, 1):
            # The description is cut to 150 characters for the summary view.
            part = (
                f"### {i}. {r.filename}\n"
                f"- **类型**: {r.part_type}\n"
                f"- **匹配度**: {r.score:.1%}\n"
                f"- **描述**: {r.description[:150]}...\n"
            )
            # Optional lines, present only when the hit carries the data.
            if r.llm_explanation:
                part += f"- **匹配原因**: {r.llm_explanation}\n"
            if r.related_parts:
                part += f"- **相关零件**: {', '.join(r.related_parts)}\n"

            output_parts.append(part)

    return {
        "success": True,
        "query": query,
        "count": len(results),
        "formatted_output": "\n".join(output_parts),
        "raw_results": [
            {
                "file_id": r.file_id,
                "filename": r.filename,
                "part_type": r.part_type,
                "score": r.score,
                "description": r.description,
                "explanation": r.llm_explanation
            } for r in results
        ]
    }


# POST /dify/chat: conversational search endpoint for Dify.
# `context` is an optional JSON-encoded chat history. It is best-effort: if it
# does not parse, or does not decode to a list, the conversation simply starts
# with an empty history. The reply text is extended with up to five hits and
# up to three follow-up suggestions.
# Responds 500 when the search engine is not initialized.
# The docstring is kept verbatim because FastAPI exposes it as the operation
# description in the generated API docs.
@app.post("/dify/chat", tags=["Dify集成"])
async def dify_chat(
    message: str = Query(..., description="用户消息"),
    context: str = Query(default="", description="上下文信息")
):
    """
    Dify对话式搜索接口
    """
    if search_engine is None:
        raise HTTPException(status_code=500, detail="搜索引擎未初始化")

    history = []
    if context:
        try:
            parsed = json.loads(context)
        except (ValueError, RecursionError):
            # Invalid JSON (JSONDecodeError is a ValueError) or absurdly deep
            # nesting: ignore the context rather than fail the request.
            parsed = None
        # Only a JSON array is a usable history.
        if isinstance(parsed, list):
            history = parsed

    result = search_engine.conversational_search(message, history)

    # "results" and "suggestions" are optional in the engine's result dict,
    # as in the /chat endpoint; "message" is required.
    results = result.get("results", [])
    suggestions = result.get("suggestions", [])

    # Assemble the Markdown response.
    parts = [result["message"]]

    if results:
        parts.append("\n\n### 找到的零件:\n")
        parts.extend(
            f"{i}. **{r.filename}** - {r.part_type}\n"
            for i, r in enumerate(results[:5], 1)
        )

    if suggestions:
        parts.append("\n\n### 您可能还想问:\n")
        parts.extend(f"- {s}\n" for s in suggestions[:3])

    return {
        "response": "".join(parts),
        "has_results": len(results) > 0,
        "suggestions": suggestions
    }


# ============== 启动事件 ==============

@app.on_event("startup")
async def startup_event():
    """Initialize the engine when the service starts and print a status banner.

    Calls ``init_engine()``, then prints the number of loaded models, whether
    the LLM is available (plus the provider name when it is) and the docs URL.
    """
    init_engine()

    rule = "=" * 50

    print("\n" + rule)
    print("🚀 三维模型智能搜索API已启动")
    print(rule)
    print(f"📊 已加载模型: {len(search_engine.models)}")

    llm_state = '可用' if search_engine.llm_available else '不可用'
    print(f"🤖 LLM状态: {llm_state}")
    if search_engine.llm_available:
        print(f"   提供者: {search_engine.llm_config.provider}")

    print("📖 API文档: http://localhost:8000/docs")
    print(rule + "\n")


# ============== 主入口 ==============

def main():
    """Run the API service with uvicorn.

    Serves the ``app`` object of the ``api_server_enhanced`` module on
    0.0.0.0:8000 with auto-reload enabled.
    """
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": True,
    }
    uvicorn.run("api_server_enhanced:app", **server_options)


# Start the server only when this file is executed directly as a script.
if __name__ == "__main__":
    main()
