"""
三维模型特征语义分析器
将JSON特征转换为可搜索的语义描述
"""

import json
import re
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
import math


@dataclass
class FeatureSemantics:
    """Semantic record produced for one analyzed feature file.

    Instances are built by ``FeatureSemanticAnalyzer.analyze_json`` and hold
    everything needed to index and search a single CAD part.
    """

    # Identifier of the analyzed file (analyze_json uses the JSON path's stem).
    file_id: str
    # Display name of the part file.
    filename: str
    # Inferred part category, e.g. "螺钉" or "三通".
    part_type: str
    # Natural-language description of the part.
    description: str
    # Search keywords derived from the filename and features.
    keywords: List[str]
    # Aggregated feature statistics (counts, dimensions, axes).
    feature_summary: Dict[str, Any]
    # Compact "type:count|..." signature used for similarity comparison.
    topology_signature: str
    # Concatenated text intended for full-text search.
    searchable_text: str


class FeatureSemanticAnalyzer:
    """Feature semantic analyzer: converts CAD feature data into searchable semantics.

    ``analyze_json`` reads one ``*_features.json`` file, computes feature
    statistics plus topology and pattern summaries, scores the part against
    ``PART_TYPE_RULES`` and returns a ``FeatureSemantics`` record.  Every
    analyzed record is also cached in ``self.analyzed_parts`` keyed by file id.
    """

    # Part-type inference rules, one entry per candidate part type:
    #   "keywords"    -- substrings searched for in the filename (case-insensitive)
    #   "features"    -- {feature type: (min, max)} expected feature-count range
    #   "patterns"    -- substrings searched for in hole types and in the filename
    #   "axis_count"  -- expected number of distinct *signed* hole directions
    #   "angle_range" -- declared for 弯头 but not read by infer_part_type
    PART_TYPE_RULES = {
        "螺钉": {
            "keywords": ["螺钉", "螺栓", "紧定"],
            "features": {"孔": (1, 2), "倒角": (4, 10), "圆锥面": (1, 3)},
            "patterns": ["内六角"]
        },
        "螺母": {
            "keywords": ["螺母", "六角"],
            "features": {"孔": (1, 1), "平面": (6, 8)},
            "patterns": ["六角"]
        },
        "喉塞": {
            "keywords": ["喉塞", "堵头"],
            "features": {"孔": (1, 1), "凸台": (1, 3), "平面": (6, 8)},
            "patterns": ["内六角"]
        },
        "三通": {
            "keywords": ["三通", "T型接头"],
            "features": {"孔": (3, 6), "凸台": (3, 4)},
            "patterns": ["交叉孔"],
            "axis_count": 3  # three holes pointing in different directions
        },
        "四通": {
            "keywords": ["四通", "十字接头", "cross"],
            "features": {"孔": (4, 8), "凸台": (4, 6)},
            "patterns": ["交叉孔"],
            "axis_count": 4  # four holes pointing in different directions
        },
        "弯头": {
            "keywords": ["弯头", "弯管", "elbow"],
            "features": {"孔": (2, 4)},
            "axis_count": 2,
            "angle_range": (45, 135)
        },
        "法兰": {
            "keywords": ["法兰", "flange"],
            "features": {"孔": (4, 20), "平面": (2, 4)},
            "patterns": ["圆周阵列"]
        },
        "轴承座": {
            "keywords": ["轴承座", "bearing"],
            "features": {"孔": (2, 6), "平面": (4, 10)},
        },
        "板材": {
            "keywords": ["板", "plate", "panel"],
            "features": {"平面": (2, 2)},
            "patterns": ["平行平面"]
        },
    }

    # Chinese feature type -> English identifier (not referenced by the methods below).
    FEATURE_TYPE_MAP = {
        "孔": "hole",
        "平面": "plane",
        "凸台": "boss",
        "倒角": "chamfer",
        "圆角": "fillet",
        "圆锥面": "cone",
        "B样条曲面": "bspline",
        "圆柱面": "cylinder",
    }

    # Hole type -> human-readable explanation used by generate_description.
    HOLE_TYPE_SEMANTICS = {
        "通孔": "贯穿的通道孔",
        "盲孔": "有底的定位孔或装配孔",
        "交叉孔": "多向交汇的管道孔",
        "沉头孔": "带有沉台的螺钉安装孔",
        "锥形孔": "带锥度的导向孔",
    }

    # ---- Feature-string parsing patterns (compiled once, not per call) ----
    # A field label may be followed by an ASCII ':' or a full-width '：'.
    _SEP = r'[:：]\s*'
    # Numeric token: digits, '.', exponent marker and signs; validated by float().
    _NUM = r'([\d.eE+-]+)'
    # Parenthesized 3-component vector such as "(0, 0, 1)"; ',' or '，' separators.
    _VEC = (r'\(\s*' + _NUM + r'\s*[,，]\s*' + _NUM
            + r'\s*[,，]\s*' + _NUM + r'\s*\)')
    _FEATURE_TYPE_RE = re.compile(r'^(\w+)特征[:：]')
    _SCALAR_FIELD_RES = {
        "半径": re.compile(r'半径' + _SEP + _NUM),
        "直径": re.compile(r'直径' + _SEP + _NUM),
        "深度": re.compile(r'深度' + _SEP + _NUM),
        "面积": re.compile(r'面积' + _SEP + _NUM),
        "高度": re.compile(r'高度' + _SEP + _NUM),
    }
    _VECTOR_FIELD_RES = {
        "位置": re.compile(r'位置' + _SEP + _VEC),
        "轴向": re.compile(r'轴向' + _SEP + _VEC),
        "方向": re.compile(r'方向' + _SEP + _VEC),
    }
    _KIND_RE = re.compile(r'类型' + _SEP + r'(\w+)')

    # Topology relation type -> analysis bucket that collects its index pairs.
    _RELATION_GROUPS = {
        "同轴": "coaxial_groups",
        "平行": "parallel_groups",
        "相邻": "adjacent_groups",
    }

    def __init__(self):
        # file_id -> FeatureSemantics for every file analyzed by analyze_json.
        self.analyzed_parts = {}

    def parse_feature_string(self, feature_str: str) -> Dict[str, Any]:
        """Parse one feature description string into a dict.

        The result always contains ``"raw"`` (the input).  ``"type"`` holds the
        word before ``特征:``.  半径/直径/深度/面积/高度 become floats; when the
        captured token is not a valid number, the raw text is kept instead.
        位置/轴向/方向 become 3-tuples of floats and are omitted when any
        component is malformed.  类型 is kept as text.
        """
        result: Dict[str, Any] = {"raw": feature_str}

        type_match = self._FEATURE_TYPE_RE.match(feature_str)
        if type_match:
            result["type"] = type_match.group(1)

        for key, regex in self._SCALAR_FIELD_RES.items():
            match = regex.search(feature_str)
            if match:
                text = match.group(1)
                try:
                    result[key] = float(text)
                except ValueError:
                    # e.g. "1.2.3" or "e": keep the raw token for inspection.
                    result[key] = text

        for key, regex in self._VECTOR_FIELD_RES.items():
            match = regex.search(feature_str)
            if match:
                try:
                    result[key] = tuple(float(x) for x in match.groups())
                except ValueError:
                    # Deliberately best-effort: drop a malformed vector instead
                    # of aborting the analysis of the whole file.
                    pass

        kind_match = self._KIND_RE.search(feature_str)
        if kind_match:
            result["类型"] = kind_match.group(1)

        return result

    @staticmethod
    def _dominant_direction(axis: Tuple[float, ...]) -> Optional[Tuple[int, int, int]]:
        """Snap an axis vector to the signed unit vector of its largest-magnitude component.

        Returns None for a zero vector, which has no direction.  The sign is
        kept so that opposite-facing ports (e.g. the two runs of a tee) count
        as distinct directions.
        """
        magnitudes = [abs(c) for c in axis]
        largest = max(magnitudes)
        if largest == 0:
            return None
        idx = magnitudes.index(largest)
        sign = 1 if axis[idx] > 0 else -1
        return tuple(sign if i == idx else 0 for i in range(3))

    def extract_feature_statistics(self, features: List[str]) -> Dict[str, Any]:
        """Aggregate statistics over a list of feature strings.

        Returns a dict with ``total_count``, ``by_type`` (count per feature
        type), ``hole_types`` (count per hole 类型), ``dimensions`` (max/min
        diameter plus the lists of numeric diameters and depths), ``axes``
        (distinct signed hole directions in first-seen order) and
        ``axis_count``.  Non-numeric diameters/depths are ignored.  When no
        hole has a diameter, both max and min diameter are 0.
        """
        by_type: Dict[str, int] = {}
        hole_types: Dict[str, int] = {}
        diameters: List[float] = []
        depths: List[float] = []
        axes: Dict[Tuple[int, int, int], None] = {}  # insertion-ordered set

        for feat_str in features:
            parsed = self.parse_feature_string(feat_str)
            feat_type = parsed.get("type", "未知")
            by_type[feat_type] = by_type.get(feat_type, 0) + 1

            if feat_type != "孔":
                continue

            hole_type = parsed.get("类型", "未知")
            hole_types[hole_type] = hole_types.get(hole_type, 0) + 1

            # parse_feature_string may keep unparsable tokens as str; skip those
            # so max()/min() never compare str with float.
            diameter = parsed.get("直径")
            if isinstance(diameter, float):
                diameters.append(diameter)
            depth = parsed.get("深度")
            if isinstance(depth, float):
                depths.append(depth)

            axis = parsed.get("轴向")
            if axis is not None:
                direction = self._dominant_direction(axis)
                if direction is not None:
                    axes[direction] = None

        axis_list = list(axes)
        return {
            "total_count": len(features),
            "by_type": by_type,
            "hole_types": hole_types,
            "dimensions": {
                "max_diameter": max([0] + diameters),
                # 0 instead of inf when there are no diameters: inf is not valid JSON.
                "min_diameter": min(diameters, default=0),
                "diameters": diameters,
                "depths": depths,
            },
            "axes": axis_list,
            "axis_count": len(axis_list),
        }

    def analyze_topology(self, topology_relations: List[Dict]) -> Dict[str, Any]:
        """Summarize topology relations.

        Counts relations per ``relation_type`` and per
        ``"<feature1_type>-<feature2_type>"`` pair.  It collects
        ``(feature1_index, feature2_index)`` pairs for 同轴/平行/相邻 relations.
        Raises KeyError if such a relation lacks an index field.
        """
        analysis = {
            "relation_counts": {},
            "feature_connections": {},
            "coaxial_groups": [],
            "parallel_groups": [],
            "adjacent_groups": [],
        }
        relation_counts = analysis["relation_counts"]
        connections = analysis["feature_connections"]

        for rel in topology_relations:
            rel_type = rel.get("relation_type", "未知")
            relation_counts[rel_type] = relation_counts.get(rel_type, 0) + 1

            connection_key = f"{rel.get('feature1_type', '')}-{rel.get('feature2_type', '')}"
            connections[connection_key] = connections.get(connection_key, 0) + 1

            group = self._RELATION_GROUPS.get(rel_type)
            if group is not None:
                analysis[group].append((rel["feature1_index"], rel["feature2_index"]))

        return analysis

    def analyze_patterns(self, patterns: List[Dict]) -> Dict[str, Any]:
        """Summarize feature patterns.

        Returns per-type counts, the distinct ``plane`` values in first-seen
        order (deterministic, unlike set iteration over strings), and the
        number of patterns whose type contains "阵列".
        """
        pattern_types: Dict[str, int] = {}
        planes: Dict[Any, None] = {}  # insertion-ordered set
        array_count = 0

        for pattern in patterns:
            p_type = pattern.get("type", "未知")
            pattern_types[p_type] = pattern_types.get(p_type, 0) + 1
            if "plane" in pattern:
                planes[pattern["plane"]] = None
            if "阵列" in p_type:
                array_count += 1

        return {
            "pattern_types": pattern_types,
            "symmetry_planes": list(planes),
            "array_count": array_count,
        }

    def infer_part_type(self, filename: str, stats: Dict, topology: Dict,
                        patterns: Dict) -> Tuple[str, float]:
        """Score every rule in PART_TYPE_RULES and return ``(best_type, confidence)``.

        Scoring per rule:
        - +50 per keyword found in the filename (case-insensitive).
        - +10 per feature count inside its range, +3 if present but out of range.
        - +15 when a pattern appears in a hole type, +20 when it appears in the
          filename (case-insensitive).
        - +25 for an exact ``axis_count`` match, +10 when ``actual >= expected - 1``.

        Confidence is ``min(score / 100, 1.0)``.  When no rule scores above
        zero, returns ``("通用零件", 0.3)`` instead of an arbitrary rule.
        ``topology`` and ``patterns`` are accepted but not currently used.
        """
        filename_lower = filename.lower()
        by_type = stats["by_type"]
        hole_types = stats.get("hole_types", {})
        actual_axes = stats.get("axis_count", 0)
        scores: Dict[str, float] = {}

        for part_type, rules in self.PART_TYPE_RULES.items():
            score = 0.0

            for kw in rules.get("keywords", []):
                if kw.lower() in filename_lower:
                    score += 50

            for feat_type, (min_count, max_count) in rules.get("features", {}).items():
                actual = by_type.get(feat_type, 0)
                if min_count <= actual <= max_count:
                    score += 10
                elif actual > 0:
                    score += 3  # feature present, count out of range

            for pattern in rules.get("patterns", []):
                if any(pattern in str(hole_type) for hole_type in hole_types):
                    score += 15
                if pattern.lower() in filename_lower:
                    score += 20

            expected_axes = rules.get("axis_count")
            if expected_axes is not None:
                if actual_axes == expected_axes:
                    score += 25
                elif actual_axes >= expected_axes - 1:
                    score += 10

            scores[part_type] = score

        best_type = max(scores, key=scores.get, default=None)
        if best_type is None or scores[best_type] <= 0:
            return "通用零件", 0.3
        return best_type, min(scores[best_type] / 100, 1.0)

    def generate_description(self, filename: str, part_type: str, stats: Dict,
                             topology: Dict, patterns: Dict) -> str:
        """Build a Chinese natural-language description of the part.

        It covers the part type, feature counts, hole types with their
        explanations, the diameter range, multi-direction channels (when
        ``axis_count >= 2``) and symmetry planes.  ``filename`` and
        ``topology`` are accepted but not used.
        """
        sentences = [f"这是一个{part_type}类型的机械零件。"]

        feature_parts = [f"{count}个{feat_type}特征"
                         for feat_type, count in stats["by_type"].items() if count > 0]
        if feature_parts:
            sentences.append(f"包含{', '.join(feature_parts)}。")

        if stats["hole_types"]:
            hole_parts = [
                f"{count}个{hole_type}({self.HOLE_TYPE_SEMANTICS.get(hole_type, hole_type)})"
                for hole_type, count in stats["hole_types"].items()
            ]
            sentences.append(f"孔特征详情: {', '.join(hole_parts)}。")

        dims = stats["dimensions"]
        if dims["diameters"]:
            if dims["max_diameter"] == dims["min_diameter"]:
                sentences.append(f"孔径统一为{dims['max_diameter']:.1f}mm。")
            else:
                sentences.append(f"孔径范围{dims['min_diameter']:.1f}mm至{dims['max_diameter']:.1f}mm。")

        if stats.get("axis_count", 0) >= 2:
            sentences.append(f"具有{stats['axis_count']}个不同轴向的通道，形成多向连通结构。")

        symmetry_planes = patterns.get("symmetry_planes")
        if symmetry_planes:
            # str() guards against non-string plane values coming from the JSON.
            planes = "、".join(str(plane) for plane in symmetry_planes)
            sentences.append(f"在{planes}上具有镜像对称性。")

        return " ".join(sentences)

    def generate_keywords(self, filename: str, part_type: str, stats: Dict,
                          topology: Dict, patterns: Dict) -> List[str]:
        """Collect de-duplicated search keywords in a deterministic, first-seen order.

        Sources, in order: filename tokens (``[...]`` groups removed,
        ``_``/``-``/``.`` treated as separators, so an extension becomes its own
        token), the part type, feature types, hole types, axis-count keywords,
        "对称" for mirror patterns and hex-socket terms when the filename
        contains "内六角".  Keywords shorter than 2 characters are dropped.
        ``topology`` is accepted but not used.
        """
        candidates: List[str] = []

        clean_name = re.sub(r'\[.*?\]', ' ', filename)
        clean_name = re.sub(r'[_\-.]', ' ', clean_name)
        candidates.extend(clean_name.split())

        candidates.append(part_type)
        candidates.extend(stats["by_type"])
        candidates.extend(stats.get("hole_types", {}))

        axis_keywords = {
            2: ["两通", "直通", "接头"],
            3: ["三通", "T型", "T接头", "分流"],
            4: ["四通", "十字", "交叉", "cross"],
        }
        candidates.extend(axis_keywords.get(stats.get("axis_count", 0), []))

        if patterns.get("pattern_types", {}).get("镜像对称", 0) > 0:
            candidates.append("对称")

        if "内六角" in filename:
            candidates.extend(["内六角", "六角", "hex", "allen"])

        # dict.fromkeys de-duplicates while keeping order, so the output (and
        # searchable_text built from it) is stable across interpreter runs.
        return [kw for kw in dict.fromkeys(candidates) if len(kw) >= 2]

    def generate_topology_signature(self, stats: Dict, topology: Dict) -> str:
        """Return a ``|``-joined signature for similarity comparison.

        The signature lists sorted ``feature_type:count`` entries, then
        ``axes:N``, then sorted ``relation_type:count`` entries.
        """
        by_type = stats["by_type"]
        relation_counts = topology.get("relation_counts", {})
        parts = [f"{feat_type}:{by_type[feat_type]}" for feat_type in sorted(by_type)]
        parts.append(f"axes:{stats.get('axis_count', 0)}")
        parts.extend(f"{rel_type}:{relation_counts[rel_type]}" for rel_type in sorted(relation_counts))
        return "|".join(parts)

    def analyze_json(self, json_path: str) -> "FeatureSemantics":
        """Analyze one feature JSON file and cache the result.

        Args:
            json_path: Path to a UTF-8 JSON file with optional ``file_info``,
                ``features``, ``topology_relations`` and ``feature_patterns`` keys.

        Returns:
            The FeatureSemantics record.  It is also stored in
            ``self.analyzed_parts`` under the file stem.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        path = Path(json_path)
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Fall back to the on-disk name when file_info is absent or null.
        file_info = data.get("file_info") or {}
        filename = file_info.get("filename", path.name)
        file_id = path.stem

        stats = self.extract_feature_statistics(data.get("features", []))
        topology = self.analyze_topology(data.get("topology_relations", []))
        patterns = self.analyze_patterns(data.get("feature_patterns", []))

        # The confidence score is currently not stored in the result.
        part_type, _confidence = self.infer_part_type(filename, stats, topology, patterns)

        description = self.generate_description(filename, part_type, stats, topology, patterns)
        keywords = self.generate_keywords(filename, part_type, stats, topology, patterns)
        topo_sig = self.generate_topology_signature(stats, topology)

        searchable_text = f"{filename} {part_type} {description} {' '.join(keywords)}"

        result = FeatureSemantics(
            file_id=file_id,
            filename=filename,
            part_type=part_type,
            description=description,
            keywords=keywords,
            feature_summary=stats,
            topology_signature=topo_sig,
            searchable_text=searchable_text
        )

        self.analyzed_parts[file_id] = result
        return result

    def analyze_directory(self, dir_path: str) -> List["FeatureSemantics"]:
        """Analyze every ``*_features.json`` file in ``dir_path`` (non-recursive).

        Files are processed in sorted order so results are reproducible.
        A file that fails is reported on stdout and skipped.
        """
        results: List["FeatureSemantics"] = []

        for json_file in sorted(Path(dir_path).glob("*_features.json")):
            try:
                result = self.analyze_json(str(json_file))
            except Exception as e:  # batch boundary: report and continue
                print(f"✗ 分析失败: {json_file.name} - {e}")
                continue
            results.append(result)
            print(f"✓ 分析完成: {json_file.name} -> {result.part_type}")

        return results

    def to_dict(self, semantics: "FeatureSemantics") -> Dict[str, Any]:
        """Return the record's fields as a plain dict.

        Values are the record's own objects, not copies.
        """
        return {
            "file_id": semantics.file_id,
            "filename": semantics.filename,
            "part_type": semantics.part_type,
            "description": semantics.description,
            "keywords": semantics.keywords,
            "feature_summary": semantics.feature_summary,
            "topology_signature": semantics.topology_signature,
            "searchable_text": semantics.searchable_text,
        }


# Manual smoke test: analyze the given feature files (or the default samples)
# and print a short summary for each one.
if __name__ == "__main__":
    import sys

    analyzer = FeatureSemanticAnalyzer()

    # Paths passed on the command line take precedence over the built-in samples.
    test_files = sys.argv[1:] or [
        "/mnt/user-data/uploads/AGES-50_features.json",
        "/mnt/user-data/uploads/prt0001_features.json",
        "/mnt/user-data/uploads/凹端内六角紧定螺钉_PSEP-6-16-A__features.json",
        "/mnt/user-data/uploads/内六角喉塞_PSEG-M22-A__features.json",
        "/mnt/user-data/uploads/内六角机螺钉_内六角平头带垫螺栓_PSHFP-M6-25-A__features.json",
    ]

    print("=" * 60)
    print("三维模型特征语义分析")
    print("=" * 60)

    for filepath in test_files:
        try:
            result = analyzer.analyze_json(filepath)
        except FileNotFoundError:
            print(f"文件未找到: {filepath}")
            continue
        except Exception as e:  # top-level boundary: report and move on to the next file
            print(f"✗ 分析失败: {filepath} - {e}")
            continue

        print(f"\n文件: {result.filename}")
        print(f"推断类型: {result.part_type}")
        print(f"关键词: {', '.join(result.keywords[:10])}")
        # Only add the ellipsis when the description was actually truncated.
        ellipsis = "..." if len(result.description) > 100 else ""
        print(f"描述: {result.description[:100]}{ellipsis}")
        print("-" * 40)
