"""
Neo4j知识图谱集成模块
将CAD特征数据导入Neo4j，支持复杂拓扑查询
"""

from typing import List, Dict, Any, Optional
import json
from dataclasses import dataclass

# Neo4j驱动导入（需要安装: pip install neo4j）
try:
    from neo4j import GraphDatabase
    NEO4J_AVAILABLE = True
except ImportError:
    NEO4J_AVAILABLE = False
    print("警告: neo4j驱动未安装，知识图谱功能不可用")


@dataclass
class Neo4jConfig:
    """Neo4j connection settings.

    The defaults point at a local Bolt endpoint with a user and a database
    both named ``neo4j``.
    """
    uri: str = "bolt://localhost:7687"  # server URI handed to GraphDatabase.driver
    username: str = "neo4j"
    password: str = "password"  # NOTE(review): placeholder credential -- override outside local testing
    database: str = "neo4j"  # database name used for every driver.session(database=...)


class CADKnowledgeGraph:
    """Knowledge-graph manager for CAD models.

    Nodes and relationships written by :meth:`import_model`:

    * ``(:Model {file_id, filename, part_type, description, keywords,
      searchable_text})``
    * ``(:Feature {feature_id, feature_type, raw_data, index})``, plus
      ``symmetry_plane`` on features that belong to a mirror-symmetry pattern
    * ``(:Model)-[:HAS_FEATURE]->(:Feature)``
    * ``(:Feature)-[:ADJACENT|COAXIAL|PARALLEL|PERPENDICULAR|CONCENTRIC|RELATED
      {details}]->(:Feature)``

    When no driver is available (the ``neo4j`` package is missing, creating the
    driver raised, or :meth:`close` was called), every method returns at once
    with ``None``, ``[]`` or ``{}``.

    The instance can be used as a context manager; the driver is closed on exit.
    """

    # Chinese relation names used in the extracted topology data, mapped to the
    # relationship types stored in the graph. Import and search share this map
    # so that they agree on the vocabulary.
    _RELATION_TYPE_MAP: Dict[str, str] = {
        "相邻": "ADJACENT",
        "同轴": "COAXIAL",
        "平行": "PARALLEL",
        "垂直": "PERPENDICULAR",
        "同心": "CONCENTRIC",
    }
    # Relationship type used for missing or unrecognised relation names.
    _DEFAULT_RELATION_TYPE = "RELATED"

    def __init__(self, config: Optional["Neo4jConfig"] = None):
        """Create a driver for ``config`` (defaults to ``Neo4jConfig()``).

        Exceptions raised while creating the driver are printed, not
        propagated; ``self.driver`` then stays ``None``.
        """
        self.config = config or Neo4jConfig()
        self.driver = None

        if NEO4J_AVAILABLE:
            try:
                self.driver = GraphDatabase.driver(
                    self.config.uri,
                    auth=(self.config.username, self.config.password)
                )
            except Exception as e:
                # Best effort: without a driver the class degrades to no-ops.
                print(f"Neo4j连接失败: {e}")

    def close(self):
        """Close the driver, if any. Afterwards every method becomes a no-op."""
        if self.driver:
            self.driver.close()
            self.driver = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _to_relationship_type(cls, relation_name: Any) -> str:
        """Map a relation name to a whitelisted relationship type.

        Chinese names are translated through ``_RELATION_TYPE_MAP``. Names that
        already are one of the whitelisted types pass through unchanged.
        Anything else becomes ``RELATED``. The result always comes from a fixed
        set of identifiers, so it is safe to splice into Cypher text.
        """
        if relation_name in cls._RELATION_TYPE_MAP:
            return cls._RELATION_TYPE_MAP[relation_name]
        allowed = set(cls._RELATION_TYPE_MAP.values())
        allowed.add(cls._DEFAULT_RELATION_TYPE)
        if relation_name in allowed:
            return relation_name
        return cls._DEFAULT_RELATION_TYPE

    @staticmethod
    def _parse_feature_type(feat_str: str) -> str:
        """Derive a feature type from a raw feature string.

        The feature type is the text before the first ``:``. The word "特征"
        ("feature") is removed from it and surrounding whitespace is stripped.
        Strings without a colon, or with nothing left after cleaning, yield
        "未知" ("unknown").
        """
        head, sep, _ = feat_str.partition(":")
        if not sep:
            return "未知"
        feature_type = head.replace("特征", "").strip()
        return feature_type or "未知"

    @staticmethod
    def _weighted_jaccard(a: Dict[Any, int], b: Dict[Any, int]) -> float:
        """Compute the weighted Jaccard similarity of two ``{type: count}`` maps.

        The result is ``sum(min) / sum(max)`` over the union of keys, and 0.0
        when both maps are empty.
        """
        keys = a.keys() | b.keys()
        intersection = sum(min(a.get(k, 0), b.get(k, 0)) for k in keys)
        union = sum(max(a.get(k, 0), b.get(k, 0)) for k in keys)
        return intersection / union if union > 0 else 0.0

    # ------------------------------------------------------------------- schema

    def setup_schema(self):
        """Create the uniqueness constraints and indexes used by this class.

        Each statement runs on its own. A failing statement is printed and
        the remaining statements still run.
        """
        if not self.driver:
            return

        with self.driver.session(database=self.config.database) as session:
            constraints = [
                "CREATE CONSTRAINT IF NOT EXISTS FOR (m:Model) REQUIRE m.file_id IS UNIQUE",
                "CREATE CONSTRAINT IF NOT EXISTS FOR (f:Feature) REQUIRE f.feature_id IS UNIQUE",
                "CREATE INDEX IF NOT EXISTS FOR (m:Model) ON (m.part_type)",
                "CREATE INDEX IF NOT EXISTS FOR (f:Feature) ON (f.feature_type)",
                "CREATE INDEX IF NOT EXISTS FOR (m:Model) ON (m.filename)",
            ]

            for constraint in constraints:
                try:
                    session.run(constraint)
                except Exception as e:
                    print(f"Schema操作: {e}")

    # ------------------------------------------------------------------- import

    def import_model(self, semantic_data: Dict[str, Any], features_data: Dict[str, Any]):
        """Import, or re-import, one CAD model.

        All writes happen in one transaction, so a failure part-way through
        leaves the graph unchanged. Any features already attached to the model,
        along with their relationships, are deleted first. Re-importing
        therefore replaces the model's features instead of accumulating them.

        Args:
            semantic_data: model-level fields. Reads ``file_id`` (required),
                ``filename``, ``part_type``, ``description``, ``keywords`` and
                ``searchable_text``.
            features_data: reads
                ``features`` (list of raw feature strings);
                ``topology_relations`` (dicts with ``feature1_index``,
                ``feature2_index`` and optional ``relation_type`` and
                ``details``); and
                ``feature_patterns`` (dicts with ``type``; the "镜像对称"
                mirror-symmetry patterns carry ``features`` indices and a
                ``plane``).
        """
        if not self.driver:
            return

        file_id = semantic_data.get("file_id", "")
        if file_id is None or file_id == "":
            # Without an id, every such model would MERGE into one shared node.
            print("跳过导入: semantic_data缺少file_id")
            return

        def make_feature_id(idx: Any) -> str:
            # Feature ids are namespaced by the owning model's file_id.
            return f"{file_id}_feat_{idx}"

        features = [
            {
                "feature_id": make_feature_id(idx),
                "feature_type": self._parse_feature_type(feat_str),
                "raw_data": feat_str,
                "index": idx,
            }
            for idx, feat_str in enumerate(features_data.get("features", []))
        ]

        with self.driver.session(database=self.config.database) as session:
            with session.begin_transaction() as tx:
                tx.run("""
                    MERGE (m:Model {file_id: $file_id})
                    SET m.filename = $filename,
                        m.part_type = $part_type,
                        m.description = $description,
                        m.keywords = $keywords,
                        m.searchable_text = $searchable_text
                """, {
                    "file_id": file_id,
                    "filename": semantic_data.get("filename", ""),
                    "part_type": semantic_data.get("part_type", ""),
                    "description": semantic_data.get("description", ""),
                    "keywords": semantic_data.get("keywords", []),
                    "searchable_text": semantic_data.get("searchable_text", "")
                })

                # Drop features (and their relationships) left by a previous
                # import of this model.
                tx.run("""
                    MATCH (:Model {file_id: $file_id})-[:HAS_FEATURE]->(old:Feature)
                    DETACH DELETE old
                """, {"file_id": file_id})

                if features:
                    tx.run("""
                        MATCH (m:Model {file_id: $file_id})
                        UNWIND $features AS feat
                        MERGE (f:Feature {feature_id: feat.feature_id})
                        SET f.feature_type = feat.feature_type,
                            f.raw_data = feat.raw_data,
                            f.index = feat.index
                        MERGE (m)-[:HAS_FEATURE]->(f)
                    """, {"file_id": file_id, "features": features})

                for rel in features_data.get("topology_relations", []):
                    # Relationship types cannot be query parameters, so the type
                    # is spliced into the text, but only from the whitelist.
                    neo4j_rel_type = self._to_relationship_type(
                        rel.get("relation_type", self._DEFAULT_RELATION_TYPE)
                    )
                    tx.run(f"""
                        MATCH (f1:Feature {{feature_id: $f1_id}})
                        MATCH (f2:Feature {{feature_id: $f2_id}})
                        MERGE (f1)-[r:{neo4j_rel_type}]->(f2)
                        SET r.details = $details
                    """, {
                        "f1_id": make_feature_id(rel["feature1_index"]),
                        "f2_id": make_feature_id(rel["feature2_index"]),
                        "details": json.dumps(rel.get("details", {}), ensure_ascii=False),
                    })

                for pat in features_data.get("feature_patterns", []):
                    if pat.get("type", "") == "镜像对称" and "features" in pat:
                        tx.run("""
                            UNWIND $feature_ids AS fid
                            MATCH (f:Feature {feature_id: fid})
                            SET f.symmetry_plane = $plane
                        """, {
                            "feature_ids": [make_feature_id(i) for i in pat["features"]],
                            "plane": pat.get("plane", ""),
                        })

                tx.commit()

    # ------------------------------------------------------------------ queries

    def search_by_part_type(self, part_type: str, limit: int = 10) -> List[Dict]:
        """Return up to ``limit`` models whose ``part_type`` contains ``part_type``."""
        if not self.driver:
            return []

        with self.driver.session(database=self.config.database) as session:
            result = session.run("""
                MATCH (m:Model)
                WHERE m.part_type CONTAINS $part_type
                RETURN m.file_id as file_id, 
                       m.filename as filename,
                       m.part_type as part_type,
                       m.description as description
                LIMIT $limit
            """, {"part_type": part_type, "limit": limit})

            return [dict(record) for record in result]

    def search_by_topology(self, feature_type: str, relation_type: str,
                           min_count: int = 1, limit: int = 10) -> List[Dict]:
        """Find models by how many features of a type take part in a relation.

        Example: models with three or more hole features ("孔") that are
        adjacent ("相邻") to another feature.

        Args:
            feature_type: matched exactly against ``Feature.feature_type``.
            relation_type: a Chinese relation name (e.g. "相邻") or a stored
                relationship type (e.g. "ADJACENT"). A feature counts only if
                it has at least one relationship of this type to another
                feature, in either direction. If empty or None, every feature
                of ``feature_type`` counts.
            min_count: minimum number of qualifying features per model.
            limit: maximum number of rows returned.

        Returns:
            Dicts with ``file_id``, ``filename``, ``part_type`` and
            ``feat_count``, ordered by ``feat_count`` descending.
        """
        if not self.driver:
            return []

        params: Dict[str, Any] = {
            "feature_type": feature_type,
            "min_count": min_count,
            "limit": limit,
        }
        if relation_type:
            # Compared through type(r) as a parameter, so any input is safe.
            params["rel_type"] = self._RELATION_TYPE_MAP.get(relation_type, relation_type)
            match_clause = """
                MATCH (m:Model)-[:HAS_FEATURE]->(f:Feature)-[r]-(:Feature)
                WHERE f.feature_type = $feature_type AND type(r) = $rel_type
            """
        else:
            match_clause = """
                MATCH (m:Model)-[:HAS_FEATURE]->(f:Feature)
                WHERE f.feature_type = $feature_type
            """

        query = match_clause + """
                WITH m, count(DISTINCT f) AS feat_count
                WHERE feat_count >= $min_count
                RETURN m.file_id AS file_id,
                       m.filename AS filename,
                       m.part_type AS part_type,
                       feat_count
                ORDER BY feat_count DESC
                LIMIT $limit
        """

        with self.driver.session(database=self.config.database) as session:
            result = session.run(query, params)
            return [dict(record) for record in result]

    def find_similar_topology(self, file_id: str, limit: int = 5) -> List[Dict]:
        """Rank other models by the similarity of their feature-type distributions.

        Similarity is the weighted Jaccard index of the ``{feature_type:
        count}`` maps. Every other model that has features is scored, and the
        top ``limit`` are returned, most similar first. The result is empty if
        ``file_id`` has no features.

        Returns:
            Dicts with ``file_id``, ``filename``, ``part_type`` and
            ``similarity``.
        """
        if not self.driver:
            return []

        with self.driver.session(database=self.config.database) as session:
            result = session.run("""
                MATCH (m:Model {file_id: $file_id})-[:HAS_FEATURE]->(f:Feature)
                RETURN f.feature_type AS feature_type, count(*) AS count
            """, {"file_id": file_id})

            target_features = {r["feature_type"]: r["count"] for r in result}
            if not target_features:
                return []

            # Score every candidate. Limiting in Cypher before scoring would
            # rank an arbitrary subset, not the most similar models.
            result = session.run("""
                MATCH (m:Model)-[:HAS_FEATURE]->(f:Feature)
                WHERE m.file_id <> $file_id
                WITH m, f.feature_type AS ft, count(*) AS cnt
                WITH m, collect({type: ft, count: cnt}) AS features
                RETURN m.file_id AS file_id,
                       m.filename AS filename,
                       m.part_type AS part_type,
                       features
            """, {"file_id": file_id})

            candidates = [
                {
                    "file_id": record["file_id"],
                    "filename": record["filename"],
                    "part_type": record["part_type"],
                    "similarity": self._weighted_jaccard(
                        target_features,
                        {f["type"]: f["count"] for f in record["features"]},
                    ),
                }
                for record in result
            ]

        candidates.sort(key=lambda c: c["similarity"], reverse=True)
        return candidates[:limit]

    def get_model_graph(self, file_id: str) -> Dict:
        """Return a model's node, its features and the relations among them.

        The result has the keys ``model``, ``features`` (ordered by ``index``)
        and ``relations`` (``from_id``, ``relation_type``, ``to_id``,
        ``details``). It is ``{}`` when the model does not exist.
        """
        if not self.driver:
            return {}

        with self.driver.session(database=self.config.database) as session:
            model_result = session.run("""
                MATCH (m:Model {file_id: $file_id})
                RETURN m
            """, {"file_id": file_id})

            model_record = model_result.single()
            if not model_record:
                return {}

            model_node = dict(model_record["m"])

            features_result = session.run("""
                MATCH (m:Model {file_id: $file_id})-[:HAS_FEATURE]->(f:Feature)
                RETURN f
                ORDER BY f.index
            """, {"file_id": file_id})

            features = [dict(r["f"]) for r in features_result]

            # Only relations whose two endpoints both belong to this model.
            relations_result = session.run("""
                MATCH (m:Model {file_id: $file_id})-[:HAS_FEATURE]->(f1:Feature)
                MATCH (f1)-[r]->(f2:Feature)
                WHERE (m)-[:HAS_FEATURE]->(f2)
                RETURN f1.feature_id as from_id,
                       type(r) as relation_type,
                       f2.feature_id as to_id,
                       r.details as details
            """, {"file_id": file_id})

            relations = [dict(r) for r in relations_result]

            return {
                "model": model_node,
                "features": features,
                "relations": relations
            }

    def cypher_query(self, query: str, params: Dict = None) -> List[Dict]:
        """Run an arbitrary Cypher query and return its records as dicts.

        The query text is executed as given: pass only trusted text, and put
        values in ``params``.
        """
        if not self.driver:
            return []

        with self.driver.session(database=self.config.database) as session:
            result = session.run(query, params or {})
            return [dict(record) for record in result]

    def get_statistics(self) -> Dict:
        """Return graph-wide counts.

        The keys are ``total_models``, ``total_features``,
        ``total_relations`` (every relationship in the database, including
        ``HAS_FEATURE``), ``models_by_type`` and ``features_by_type``.
        """
        if not self.driver:
            return {}

        with self.driver.session(database=self.config.database) as session:
            stats = {}

            result = session.run("MATCH (m:Model) RETURN count(m) as count")
            stats["total_models"] = result.single()["count"]

            result = session.run("MATCH (f:Feature) RETURN count(f) as count")
            stats["total_features"] = result.single()["count"]

            result = session.run("MATCH ()-[r]->() RETURN count(r) as count")
            stats["total_relations"] = result.single()["count"]

            result = session.run("""
                MATCH (m:Model)
                RETURN m.part_type as part_type, count(*) as count
                ORDER BY count DESC
            """)
            stats["models_by_type"] = {r["part_type"]: r["count"] for r in result}

            result = session.run("""
                MATCH (f:Feature)
                RETURN f.feature_type as feature_type, count(*) as count
                ORDER BY count DESC
            """)
            stats["features_by_type"] = {r["feature_type"]: r["count"] for r in result}

            return stats


# ============== 示例Cypher查询模板 ==============

# Ready-made Cypher queries over the graph written by
# CADKnowledgeGraph.import_model. Keys are template names and values are Cypher
# strings that can be passed to CADKnowledgeGraph.cypher_query. Only
# "topology_analysis" takes a parameter ($file_id); the others take none.
CYPHER_TEMPLATES = {
    # Models with at least 2 '孔' (hole) features whose raw text contains
    # '交叉孔' (cross hole). Meant to surface pipe fittings such as tees and crosses.
    "find_pipe_fittings": """
        // 查找管件（三通、四通等）
        MATCH (m:Model)-[:HAS_FEATURE]->(f:Feature {feature_type: '孔'})
        WHERE f.raw_data CONTAINS '交叉孔'
        WITH m, count(f) as cross_hole_count
        WHERE cross_hole_count >= 2
        RETURN m.filename, m.part_type, cross_hole_count
        ORDER BY cross_hole_count DESC
    """,
    # Models with a keyword containing '螺' (as in 螺纹, thread) or '内六角'
    # (hex socket).
    "find_threaded_parts": """
        // 查找螺纹零件
        MATCH (m:Model)
        WHERE any(kw IN m.keywords WHERE kw CONTAINS '螺' OR kw CONTAINS '内六角')
        RETURN m.filename, m.part_type, m.keywords
    """,
    # Models whose features carry at least two distinct symmetry_plane values.
    # import_model sets symmetry_plane for mirror-symmetry patterns.
    "find_symmetric_models": """
        // 查找对称零件
        MATCH (m:Model)-[:HAS_FEATURE]->(f:Feature)
        WHERE f.symmetry_plane IS NOT NULL
        WITH m, collect(DISTINCT f.symmetry_plane) as symmetry_planes
        WHERE size(symmetry_planes) >= 2
        RETURN m.filename, symmetry_planes
    """,
    # For one model ($file_id), counts each (from_type, relation, to_type)
    # combination among its outgoing feature-to-feature relationships.
    "topology_analysis": """
        // 分析特征拓扑关系
        MATCH (m:Model {file_id: $file_id})-[:HAS_FEATURE]->(f1:Feature)
        MATCH (f1)-[r]->(f2:Feature)
        RETURN f1.feature_type as from_type,
               type(r) as relation,
               f2.feature_type as to_type,
               count(*) as count
        ORDER BY count DESC
    """,
}


# ============== 使用示例 ==============

def example_usage():
    """Demonstrate typical use of CADKnowledgeGraph.

    Connects with placeholder credentials and creates the schema. It then runs
    a part-type search and a topology search, prints their results, and prints
    the total model count. The count is shown as "N/A" when statistics are
    unavailable, because get_statistics() returns {} without a driver.
    """
    config = Neo4jConfig(
        uri="bolt://localhost:7687",
        username="neo4j",
        password="your_password"  # placeholder: replace with the real password
    )

    with CADKnowledgeGraph(config) as kg:
        kg.setup_schema()

        # Part-type search for tees ("三通")
        results = kg.search_by_part_type("三通")
        for r in results:
            print(f"找到: {r['filename']} - {r['part_type']}")

        # Topology search: hole features ("孔"), relation "相邻" (adjacent),
        # at least 3 per model
        results = kg.search_by_topology(
            feature_type="孔",
            relation_type="相邻",
            min_count=3
        )
        for r in results:
            print(f"拓扑匹配: {r['filename']} - {r['feat_count']}")

        # get_statistics() returns {} without a driver, so don't index blindly.
        stats = kg.get_statistics()
        print(f"总模型数: {stats.get('total_models', 'N/A')}")


if __name__ == "__main__":
    # Script entry: print a banner, report driver availability and, when the
    # driver is installed, preview each Cypher template (first 200 chars).
    banner = "Neo4j知识图谱模块"
    print(banner)
    print("=" * 50)

    if NEO4J_AVAILABLE:
        print("Neo4j驱动已就绪")
        print("\n示例Cypher查询模板:")
        max_preview = 200
        for template_name, template_query in CYPHER_TEMPLATES.items():
            print(f"\n--- {template_name} ---")
            too_long = len(template_query) > max_preview
            print(f"{template_query[:max_preview]}..." if too_long else template_query)
    else:
        print("请安装neo4j驱动: pip install neo4j")
