"""
三维模型语义搜索引擎
支持关键词搜索、语义向量搜索和混合检索
"""

import json
import re
import math
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Tuple
from collections import defaultdict
import numpy as np
from pathlib import Path

# 简化的文本向量化（生产环境建议使用sentence-transformers或OpenAI embeddings）
class SimpleTextVectorizer:
    """Minimal TF-IDF vectorizer for mixed Chinese/English text.

    Tokens are single CJK characters (U+4E00..U+9FFF), lower-cased runs of
    other alphanumeric characters, and the concatenation of every pair of
    adjacent tokens (bigrams).  Vectors produced by ``transform`` are
    L2-normalised, so the dot product of two of them is their cosine
    similarity.
    """

    def __init__(self):
        self.vocabulary = {}  # token -> column index in the vector
        self.idf = {}  # token -> inverse document frequency weight
        self.vocab_size = 0

    def fit(self, documents: List[str]):
        """Build the vocabulary and IDF table from ``documents``.

        State from a previous ``fit`` is discarded, so ``idf`` never holds
        tokens that are absent from ``vocabulary``.

        Args:
            documents: The corpus; each entry is one document's text.
        """
        # Only token presence per document matters for both the vocabulary
        # and the document frequencies, so tokenize once into sets.
        doc_token_sets = [set(self._tokenize(doc)) for doc in documents]

        all_tokens = set().union(*doc_token_sets)
        # Sorting makes the column assignment independent of set ordering.
        self.vocabulary = {token: idx for idx, token in enumerate(sorted(all_tokens))}
        self.vocab_size = len(self.vocabulary)

        doc_count = len(documents)
        doc_freq = defaultdict(int)
        for token_set in doc_token_sets:
            for token in token_set:
                doc_freq[token] += 1

        # Rebuild the table from scratch instead of updating it in place;
        # otherwise tokens from an earlier corpus would linger after a refit.
        self.idf = {
            token: math.log(doc_count / (1 + freq)) + 1
            for token, freq in doc_freq.items()
        }

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into unigram tokens followed by adjacent bigrams.

        Args:
            text: Raw text to tokenize.

        Returns:
            All unigrams in reading order, then one bigram per adjacent
            unigram pair (the two tokens concatenated without a separator).
        """
        # Replace punctuation with spaces so it acts as a token separator.
        text = re.sub(r'[^\w\s]', ' ', text)

        tokens = []
        current_word = ""
        for char in text:
            if '\u4e00' <= char <= '\u9fff':
                # A CJK character ends any pending word and is its own token.
                if current_word:
                    tokens.append(current_word.lower())
                    current_word = ""
                tokens.append(char)
            elif char.isalnum():
                current_word += char
            elif current_word:
                # Whitespace, underscore, etc. terminate the pending word.
                tokens.append(current_word.lower())
                current_word = ""
        if current_word:
            tokens.append(current_word.lower())

        bigrams = [left + right for left, right in zip(tokens, tokens[1:])]
        return tokens + bigrams

    def transform(self, text: str) -> np.ndarray:
        """Convert ``text`` into an L2-normalised TF-IDF vector.

        Args:
            text: Text to vectorize.

        Returns:
            A 1-D array of length ``vocab_size``.  An empty array is returned
            when no vocabulary has been built; an all-zero vector is returned
            when ``text`` shares no token with the vocabulary.
        """
        if self.vocab_size == 0:
            return np.array([])

        term_counts = defaultdict(int)
        for token in self._tokenize(text):
            term_counts[token] += 1

        vector = np.zeros(self.vocab_size)
        for token, count in term_counts.items():
            idx = self.vocabulary.get(token)
            if idx is None:
                continue  # out-of-vocabulary tokens contribute nothing
            # Sub-linear term frequency; count is always >= 1 here.
            tf_val = 1 + math.log(count)
            vector[idx] = tf_val * self.idf.get(token, 1)

        norm = np.linalg.norm(vector)
        if norm > 0:
            vector = vector / norm

        return vector


@dataclass
class SearchResult:
    """A single ranked hit returned by the search engine."""
    file_id: str  # identifier of the matched model
    filename: str  # the model's "filename" field ("" when missing)
    part_type: str  # the model's "part_type" field ("" when missing)
    score: float  # ranking score; its scale depends on the search method
    match_type: str  # one of "keyword", "semantic", "hybrid"
    highlights: List[str]  # human-readable notes on why the model matched
    description: str  # the model's description, cut to 200 chars by the engine


class CADModelSearchEngine:
    """Search engine over semantic descriptions of 3D CAD models.

    Each model is a dict with a ``file_id`` plus optional ``filename``,
    ``part_type``, ``keywords``, ``description`` and ``searchable_text``
    fields.  Three retrieval modes are offered: keyword matching with
    synonym expansion, TF-IDF cosine similarity, and a weighted blend of
    the two.
    """

    def __init__(self):
        self.models = {}  # file_id -> semantic data
        self.vectors = {}  # file_id -> TF-IDF vector
        self.vectorizer = SimpleTextVectorizer()
        self.keyword_index = defaultdict(set)  # lower-cased keyword -> file_ids
        self.part_type_index = defaultdict(set)  # part_type -> file_ids

        # Synonym / alias table: canonical term -> alternative names.
        self.synonyms = {
            "三通": ["T型", "T接头", "分流接头", "三向接头", "tee"],
            "四通": ["十字接头", "cross", "四向接头", "交叉接头"],
            "螺钉": ["螺栓", "螺丝", "screw", "bolt"],
            "螺母": ["nut", "六角母"],
            "法兰": ["flange", "连接盘"],
            "弯头": ["弯管", "elbow", "90度弯头", "45度弯头"],
            "垫圈": ["垫片", "washer"],
            "内六角": ["allen", "hex", "六角"],
            "喉塞": ["堵头", "plug", "闷头"],
        }

        # Reverse lookup: lower-cased term or alias -> canonical term.
        self.synonym_reverse = {}
        for main_term, synonyms in self.synonyms.items():
            self.synonym_reverse[main_term.lower()] = main_term
            for syn in synonyms:
                self.synonym_reverse[syn.lower()] = main_term

    # ------------------------------------------------------------------
    # Indexing
    # ------------------------------------------------------------------

    def index_model(self, semantic_data: Dict[str, Any]):
        """Add one model to the keyword and part-type indexes.

        Re-indexing an existing ``file_id`` replaces its previous entries.
        The vector index is not touched; call ``build_vector_index`` after
        adding or replacing models.

        Args:
            semantic_data: Model description; must contain ``"file_id"``.

        Raises:
            KeyError: If ``semantic_data`` has no ``"file_id"``.
        """
        file_id = semantic_data["file_id"]
        if file_id in self.models:
            # Drop the old keywords/part type so they stop matching.
            self._unindex(file_id)
        self.models[file_id] = semantic_data
        self._index_terms(file_id, semantic_data)

    def _index_terms(self, file_id, data):
        """Register ``data``'s keywords and part type under ``file_id``."""
        for keyword in data.get("keywords") or []:
            key = keyword.strip().lower()
            # An empty key is a substring of every query and would match all.
            if key:
                self.keyword_index[key].add(file_id)

        part_type = data.get("part_type", "")
        if part_type:
            self.part_type_index[part_type].add(file_id)

    def _unindex(self, file_id):
        """Remove ``file_id`` from the keyword and part-type indexes."""
        for index in (self.keyword_index, self.part_type_index):
            # Collect keys first: the dict is mutated while cleaning up.
            for key in [k for k, ids in index.items() if file_id in ids]:
                index[key].discard(file_id)
                if not index[key]:
                    del index[key]

    def build_vector_index(self):
        """Fit the vectorizer on all models and (re)build every vector.

        Does nothing when no model has been indexed.
        """
        if not self.models:
            return

        file_ids = list(self.models)
        documents = [
            self.models[file_id].get("searchable_text") or ""
            for file_id in file_ids
        ]

        self.vectorizer.fit(documents)
        # Rebuild the mapping wholesale: vectors from an older vocabulary
        # have a different dimensionality and must not survive.
        self.vectors = {
            file_id: self.vectorizer.transform(doc)
            for file_id, doc in zip(file_ids, documents)
        }

    # ------------------------------------------------------------------
    # Query handling
    # ------------------------------------------------------------------

    def expand_query(self, query: str) -> List[str]:
        """Expand ``query`` with the synonyms of every term it mentions.

        Args:
            query: The user's query text.

        Returns:
            The original query first, followed by canonical terms and
            aliases in a deterministic order, without duplicates.
        """
        expanded = [query]
        query_lower = query.lower()

        # The whole query is itself a known term or alias.
        main_term = self.synonym_reverse.get(query_lower)
        if main_term is not None:
            expanded.append(main_term)
            expanded.extend(self.synonyms.get(main_term, []))

        # The query contains a known term or alias somewhere inside it.
        for term, synonyms in self.synonyms.items():
            mentioned = term.lower() in query_lower or any(
                syn.lower() in query_lower for syn in synonyms
            )
            if mentioned:
                expanded.append(term)
                expanded.extend(synonyms)

        # dict.fromkeys de-duplicates while keeping first-seen order.
        return list(dict.fromkeys(expanded))

    def keyword_search(self, query: str, top_k: int = 10) -> List[Tuple[str, float]]:
        """Rank models by keyword, part-type and filename matches.

        Each distinct (case-insensitive) expanded query term adds: 2.0 for
        an exact keyword match, 1.0 for a substring match against a keyword
        in either direction, 1.5 for a substring match against a part type
        in either direction, and 2.5 when it occurs in the filename.

        Args:
            query: The user's query text.
            top_k: Maximum number of results to return.

        Returns:
            ``(file_id, score)`` pairs, best first; ties are ordered by
            ``file_id``.  Empty for an empty or whitespace-only query.
        """
        if not query or not query.strip():
            # "" is a substring of everything; it must not match all models.
            return []

        scores = defaultdict(float)
        seen_terms = set()

        for q in self.expand_query(query):
            q_lower = q.strip().lower()
            # Skip case variants of a term already scored ("BOLT"/"bolt").
            if not q_lower or q_lower in seen_terms:
                continue
            seen_terms.add(q_lower)

            # Exact keyword match.
            for file_id in self.keyword_index.get(q_lower, ()):
                scores[file_id] += 2.0

            # Partial keyword match (also fires on exact matches).
            for keyword, file_ids in self.keyword_index.items():
                if keyword and (q_lower in keyword or keyword in q_lower):
                    for file_id in file_ids:
                        scores[file_id] += 1.0

            # Part-type match.
            for part_type, file_ids in self.part_type_index.items():
                pt_lower = part_type.lower()
                if pt_lower and (q_lower in pt_lower or pt_lower in q_lower):
                    for file_id in file_ids:
                        scores[file_id] += 1.5

            # Filename match.
            for file_id, data in self.models.items():
                if q_lower in (data.get("filename") or "").lower():
                    scores[file_id] += 2.5

        ranked = sorted(scores.items(), key=lambda item: (-item[1], str(item[0])))
        return ranked[:max(top_k, 0)]

    def semantic_search(self, query: str, top_k: int = 10) -> List[Tuple[str, float]]:
        """Rank models by TF-IDF cosine similarity to ``query``.

        Args:
            query: The user's query text.
            top_k: Maximum number of results to return.

        Returns:
            ``(file_id, similarity)`` pairs with similarity > 0, best first;
            ties are ordered by ``file_id``.  Empty when no vector index
            exists or the query shares no token with the vocabulary.
        """
        if not query or not self.vectors:
            return []

        query_vector = self.vectorizer.transform(query)
        if query_vector.size == 0 or not query_vector.any():
            return []

        # Vectors are L2-normalised, so the dot product is the cosine.
        scores = {}
        for file_id, doc_vector in self.vectors.items():
            if doc_vector.size == 0:
                continue
            similarity = float(np.dot(query_vector, doc_vector))
            # A zero similarity means "no overlap", not a weak match.
            if similarity > 0.0:
                scores[file_id] = similarity

        ranked = sorted(scores.items(), key=lambda item: (-item[1], str(item[0])))
        return ranked[:max(top_k, 0)]

    @staticmethod
    def _normalize_scores(results):
        """Scale ``(file_id, score)`` pairs into ``{file_id: score / max}``."""
        if not results:
            return {}
        max_score = max(score for _, score in results)
        if max_score <= 0:
            return {file_id: 0 for file_id, _ in results}
        return {file_id: score / max_score for file_id, score in results}

    @staticmethod
    def _build_highlights(data, terms):
        """Describe which keywords / part type of ``data`` match ``terms``.

        ``terms`` must already be lower-cased and non-empty.
        """
        highlights = []
        for kw in data.get("keywords") or []:
            kw_lower = kw.strip().lower()
            if kw_lower and any(t in kw_lower or kw_lower in t for t in terms):
                highlights.append(f"关键词匹配: {kw}")

        part_type = data.get("part_type", "")
        if part_type:
            pt_lower = part_type.lower()
            # Same two-way substring test that keyword_search scores with.
            if any(t in pt_lower or pt_lower in t for t in terms):
                highlights.append(f"类型匹配: {part_type}")
        return highlights

    def _make_result(self, file_id, score, match_type, highlights=None):
        """Build a ``SearchResult`` for ``file_id`` from the stored model."""
        data = self.models.get(file_id, {})
        return SearchResult(
            file_id=file_id,
            filename=data.get("filename", ""),
            part_type=data.get("part_type", ""),
            score=score,
            match_type=match_type,
            highlights=list(highlights or []),
            # ``or ""`` keeps a stored None description from breaking the slice.
            description=(data.get("description") or "")[:200],
        )

    def hybrid_search(self, query: str, top_k: int = 10,
                      keyword_weight: float = 0.6,
                      semantic_weight: float = 0.4) -> "List[SearchResult]":
        """Blend keyword and semantic rankings into one result list.

        Both rankings are normalised by their own maximum and combined as
        ``keyword_weight * kw + semantic_weight * sem``.

        Args:
            query: The user's query text.
            top_k: Maximum number of results to return.
            keyword_weight: Weight of the normalised keyword score.
            semantic_weight: Weight of the normalised semantic score.

        Returns:
            Up to ``top_k`` results, best first; ``match_type`` tells which
            method(s) found each model.  Empty when nothing matches.
        """
        # Over-fetch so a model just outside top_k in one ranking can still
        # surface once both scores are combined.
        keyword_scores = self._normalize_scores(self.keyword_search(query, top_k * 2))
        semantic_scores = self._normalize_scores(self.semantic_search(query, top_k * 2))

        combined_scores = {
            file_id: (keyword_weight * keyword_scores.get(file_id, 0)
                      + semantic_weight * semantic_scores.get(file_id, 0))
            for file_id in set(keyword_scores) | set(semantic_scores)
        }
        if not combined_scores:
            return []

        ranked_ids = sorted(
            combined_scores,
            key=lambda fid: (-combined_scores[fid], str(fid)),
        )[:max(top_k, 0)]

        terms = [q.strip().lower() for q in self.expand_query(query)]
        terms = [t for t in terms if t]

        results = []
        for file_id in ranked_ids:
            data = self.models.get(file_id, {})
            highlights = self._build_highlights(data, terms)

            in_keyword = keyword_scores.get(file_id, 0) > 0
            in_semantic = semantic_scores.get(file_id, 0) > 0
            if in_keyword and in_semantic:
                match_type = "hybrid"
            elif in_keyword:
                match_type = "keyword"
            else:
                match_type = "semantic"

            results.append(self._make_result(
                file_id, combined_scores[file_id], match_type, highlights[:5]
            ))

        return results

    def search(self, query: str, top_k: int = 10,
               method: str = "hybrid") -> "List[SearchResult]":
        """Unified search entry point.

        Args:
            query: The user's query text.
            top_k: Maximum number of results to return.
            method: ``"keyword"``, ``"semantic"`` or ``"hybrid"``; any other
                value falls back to hybrid search.

        Returns:
            Ranked results.  For ``"keyword"`` and ``"semantic"`` the score
            is the raw score of that method and ``highlights`` is empty.
        """
        if method in ("keyword", "semantic"):
            runner = self.keyword_search if method == "keyword" else self.semantic_search
            return [
                self._make_result(file_id, score, method)
                for file_id, score in runner(query, top_k)
            ]
        return self.hybrid_search(query, top_k)

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save_index(self, filepath: str):
        """Write the models, vocabulary and IDF table to a JSON file.

        Vectors and the keyword/part-type indexes are not stored;
        ``load_index`` rebuilds them.

        Args:
            filepath: Destination path; an existing file is overwritten.
        """
        data = {
            "models": self.models,
            "vocabulary": self.vectorizer.vocabulary,
            "idf": self.vectorizer.idf,
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    def load_index(self, filepath: str):
        """Replace the engine's state with an index saved by ``save_index``.

        Args:
            filepath: Path of the JSON file to read.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        self.models = data.get("models", {})
        self.vectorizer.vocabulary = data.get("vocabulary", {})
        self.vectorizer.idf = data.get("idf", {})
        self.vectorizer.vocab_size = len(self.vectorizer.vocabulary)

        # Start from empty indexes: entries of previously held models would
        # otherwise keep matching even though the models are gone.
        self.keyword_index.clear()
        self.part_type_index.clear()
        for file_id, model_data in self.models.items():
            self._index_terms(file_id, model_data)

        # Recompute vectors with the loaded vocabulary.
        self.vectors = {
            file_id: self.vectorizer.transform(model_data.get("searchable_text") or "")
            for file_id, model_data in self.models.items()
        }


def create_search_index_from_json_files(json_files: List[str]) -> CADModelSearchEngine:
    """Build a ready-to-query search engine from feature JSON files.

    Every path is analysed with ``FeatureSemanticAnalyzer`` and the result
    is indexed.  A file that fails at any step is reported on stdout and
    skipped; the remaining files are still processed.  The vector index is
    built once at the end.

    Args:
        json_files: Paths of the feature JSON files to index.

    Returns:
        The populated search engine.
    """
    # Imported here rather than at module level; the engine classes
    # themselves do not use the analyzer.
    from feature_semantic_analyzer import FeatureSemanticAnalyzer

    analyzer = FeatureSemanticAnalyzer()
    engine = CADModelSearchEngine()

    for path in json_files:
        try:
            parsed = analyzer.analyze_json(path)
            engine.index_model(analyzer.to_dict(parsed))
            print(f"✓ 索引: {parsed.filename} -> {parsed.part_type}")
        except Exception as exc:
            # Best effort: one bad file must not abort the whole batch.
            print(f"✗ 索引失败: {path} - {exc}")

    engine.build_vector_index()
    return engine


# Manual smoke test: index a few sample files and run some queries.
if __name__ == "__main__":
    import sys
    sys.path.insert(0, '/home/claude/cad_semantic_search')

    from feature_semantic_analyzer import FeatureSemanticAnalyzer

    banner = "=" * 60

    # Sample feature files to index.
    sample_files = [
        "/mnt/user-data/uploads/AGES-50_features.json",
        "/mnt/user-data/uploads/prt0001_features.json",
        "/mnt/user-data/uploads/凹端内六角紧定螺钉_PSEP-6-16-A__features.json",
        "/mnt/user-data/uploads/内六角喉塞_PSEG-M22-A__features.json",
        "/mnt/user-data/uploads/内六角机螺钉_内六角平头带垫螺栓_PSHFP-M6-25-A__features.json",
    ]

    # Build the index.
    print(banner)
    print("创建搜索索引")
    print(banner)

    engine = create_search_index_from_json_files(sample_files)

    # Run the sample queries with the default (hybrid) method.
    print("\n" + banner)
    print("搜索测试")
    print(banner)

    for query in ("三通", "螺钉", "内六角", "四通", "喉塞", "交叉孔"):
        print(f"\n🔍 搜索: '{query}'")
        print("-" * 40)

        hits = engine.search(query, top_k=5)

        if not hits:
            print("  无匹配结果")
            continue

        for rank, hit in enumerate(hits, 1):
            print(f"  {rank}. {hit.filename}")
            print(f"     类型: {hit.part_type} | 分数: {hit.score:.3f} | 匹配: {hit.match_type}")
            if hit.highlights:
                shown = ', '.join(hit.highlights[:3])
                print(f"     高亮: {shown}")
