#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
螺纹验证模块 - 解决螺纹误识别问题
提供多层次的验证机制
"""

import logging
import math
from typing import Dict, Any, Tuple, List
from OCC.Core.BRepAdaptor import BRepAdaptor_Surface
from OCC.Core.GeomAbs import GeomAbs_Cylinder, GeomAbs_Torus, GeomAbs_Cone
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.TopAbs import TopAbs_FACE
from OCC.Core.TopoDS import topods

logger = logging.getLogger(__name__)


class ThreadVerificationResult:
    """螺纹验证结果"""
    def __init__(self):
        self.is_thread = False
        self.confidence = 0.0
        self.rejection_reasons = []  # 拒绝原因
        self.positive_indicators = []  # 正面指标
        self.scores = {}  # 各项得分


class ThreadVerifier:
    """螺纹验证器 - 多层次验证避免误判"""
    
    # 硬性排除条件
    EXCLUSION_CHECKS = {
        'has_torus_neighbor': "检测到环面邻居(O型圈槽特征)",
        'has_cone_face': "检测到锥面(倒角特征)",
        'too_short_cylinder': "圆柱体过短(长径比<0.5)",
        'too_few_bsplines': "B样条曲线数量过少(<4)",
        'no_z_variation': "Z方向无变化",
        'excessive_circular_edges': "圆形边缘过多(>50%)",
        'irregular_geometry': "几何形状不规则",
        'low_edge_density': "边缘密度过低",
    }
    
    # 置信度阈值
    MIN_CONFIDENCE = 0.80  # 提高到80%
    MIN_POSITIVE_SCORE = 0.75  # 正面指标得分阈值
    
    @staticmethod
    def verify_thread(face, shape, edge_data: Dict[str, Any]) -> ThreadVerificationResult:
        """
        多层次验证是否为螺纹
        
        Args:
            face: 候选螺纹面
            shape: 整体形状
            edge_data: 边缘分析数据
            
        Returns:
            ThreadVerificationResult: 验证结果
        """
        result = ThreadVerificationResult()
        
        try:
            # ========== 第一阶段: 硬性排除检查 ==========
            exclusions = ThreadVerifier._check_exclusions(face, shape, edge_data)
            
            if exclusions:
                result.is_thread = False
                result.confidence = 0.0
                result.rejection_reasons = exclusions
                logger.info(f"螺纹被排除: {', '.join(exclusions)}")
                return result
            
            # ========== 第二阶段: 上下文验证 ==========
            context_score = ThreadVerifier._verify_context(face, shape, edge_data)
            result.scores['context'] = context_score
            
            if context_score < 0.5:
                result.rejection_reasons.append(f"上下文得分过低: {context_score:.2f}")
            
            # ========== 第三阶段: 几何特征验证 ==========
            geometry_score = ThreadVerifier._verify_geometry(edge_data)
            result.scores['geometry'] = geometry_score
            
            if geometry_score < 0.6:
                result.rejection_reasons.append(f"几何得分过低: {geometry_score:.2f}")
            
            # ========== 第四阶段: 螺纹特定特征验证 ==========
            thread_specific_score = ThreadVerifier._verify_thread_specific_features(edge_data)
            result.scores['thread_specific'] = thread_specific_score
            
            if thread_specific_score < 0.7:
                result.rejection_reasons.append(f"螺纹特征得分过低: {thread_specific_score:.2f}")
            
            # ========== 第五阶段: 标准螺纹匹配 ==========
            standard_match_score = ThreadVerifier._verify_standard_match(face, edge_data)
            result.scores['standard_match'] = standard_match_score
            
            # ========== 综合判断 ==========
            # 加权平均
            weights = {
                'context': 0.20,
                'geometry': 0.25,
                'thread_specific': 0.40,
                'standard_match': 0.15
            }
            
            overall_confidence = sum(
                result.scores.get(key, 0) * weight 
                for key, weight in weights.items()
            )
            
            result.confidence = overall_confidence
            
            # 最终判断 - 需要同时满足多个条件
            result.is_thread = (
                overall_confidence >= ThreadVerifier.MIN_CONFIDENCE and
                len(result.rejection_reasons) == 0 and
                thread_specific_score >= 0.7 and  # 螺纹特定特征必须强
                context_score >= 0.5  # 上下文必须合理
            )
            
            if result.is_thread:
                result.positive_indicators = [
                    f"置信度: {overall_confidence:.2%}",
                    f"螺纹特征得分: {thread_specific_score:.2f}",
                    f"几何得分: {geometry_score:.2f}",
                    f"上下文得分: {context_score:.2f}"
                ]
                logger.info(f"✓ 验证为螺纹，置信度={overall_confidence:.2%}")
            else:
                logger.info(f"✗ 不是螺纹，置信度={overall_confidence:.2%}, "
                          f"原因: {result.rejection_reasons}")
            
            return result
            
        except Exception as e:
            logger.error(f"螺纹验证时出错: {str(e)}")
            result.is_thread = False
            result.confidence = 0.0
            result.rejection_reasons.append(f"验证错误: {str(e)}")
            return result
    
    @staticmethod
    def _check_exclusions(face, shape, edge_data: Dict) -> List[str]:
        """
        硬性排除条件检查
        
        Returns:
            List[str]: 排除原因列表，空列表表示未被排除
        """
        exclusions = []
        
        try:
            surf = BRepAdaptor_Surface(face)
            
            # 1. 必须是圆柱面
            if surf.GetType() != GeomAbs_Cylinder:
                exclusions.append("非圆柱面")
                return exclusions
            
            cylinder = surf.Cylinder()
            radius = cylinder.Radius()
            diameter = radius * 2
            
            # 2. 检查相邻面 - 排除O型圈槽和倒角
            has_torus, has_cone = ThreadVerifier._check_adjacent_faces(face, shape)
            
            if has_torus:
                exclusions.append(ThreadVerifier.EXCLUSION_CHECKS['has_torus_neighbor'])
            
            if has_cone:
                exclusions.append(ThreadVerifier.EXCLUSION_CHECKS['has_cone_face'])
            
            # 3. 检查长径比
            z_range = edge_data.get('z_range', 0)
            if z_range > 0:
                length_to_diameter = z_range / diameter
                if length_to_diameter < 0.5:
                    exclusions.append(ThreadVerifier.EXCLUSION_CHECKS['too_short_cylinder'])
            
            # 4. B样条数量检查 - 螺纹必须有足够的B样条
            bspline_count = edge_data.get('bspline_count', 0)
            if bspline_count < 4:
                exclusions.append(ThreadVerifier.EXCLUSION_CHECKS['too_few_bsplines'])
            
            # 5. Z方向变化检查
            if not edge_data.get('has_z_variation', False):
                exclusions.append(ThreadVerifier.EXCLUSION_CHECKS['no_z_variation'])
            
            # 6. 圆形边缘过多检查 - 螺纹不应该有太多圆形边缘
            edge_count = edge_data.get('edge_count', 0)
            circle_count = edge_data.get('circle_count', 0)
            
            if edge_count > 0:
                circle_ratio = circle_count / edge_count
                if circle_ratio > 0.5:  # 超过50%
                    exclusions.append(ThreadVerifier.EXCLUSION_CHECKS['excessive_circular_edges'])
            
            # 7. 直线边缘过多检查
            line_count = edge_data.get('line_count', 0)
            if edge_count > 0:
                line_ratio = line_count / edge_count
                if line_ratio > 0.5:  # 超过50%
                    exclusions.append("直线边缘过多(非螺旋特征)")
            
            # 8. 边缘密度检查
            edge_density = edge_data.get('edge_density', 0)
            if edge_density < 2.0:
                exclusions.append(ThreadVerifier.EXCLUSION_CHECKS['low_edge_density'])
            
            # 9. 拓扑错误检查
            topo_errors = edge_data.get('topological_errors', 0)
            if topo_errors > 3:
                exclusions.append(f"拓扑错误过多({topo_errors})")
            
            return exclusions
            
        except Exception as e:
            logger.error(f"排除检查时出错: {str(e)}")
            exclusions.append(f"检查错误: {str(e)}")
            return exclusions
    
    @staticmethod
    def _check_adjacent_faces(face, shape) -> Tuple[bool, bool]:
        """
        检查相邻面类型
        
        Returns:
            Tuple[bool, bool]: (是否有环面, 是否有锥面)
        """
        has_torus = False
        has_cone = False
        
        try:
            # 获取所有面，检查是否有环面或锥面与圆柱面接近
            face_explorer = TopExp_Explorer(shape, TopAbs_FACE)
            
            # 获取候选面的中心位置
            from OCC.Core.BRepGProp import brepgprop
            from OCC.Core.GProp import GProp_GProps
            
            props = GProp_GProps()
            brepgprop.SurfaceProperties(face, props)
            face_center = props.CentreOfMass()
            
            while face_explorer.More():
                other_face = topods.Face(face_explorer.Current())
                
                if not face.IsSame(other_face):
                    other_surf = BRepAdaptor_Surface(other_face)
                    surf_type = other_surf.GetType()
                    
                    # 检查类型
                    if surf_type == GeomAbs_Torus:
                        # 计算距离，如果很近则可能是O型圈槽
                        other_props = GProp_GProps()
                        brepgprop.SurfaceProperties(other_face, other_props)
                        other_center = other_props.CentreOfMass()
                        
                        distance = face_center.Distance(other_center)
                        if distance < 10.0:  # 10mm以内认为是相邻
                            has_torus = True
                            logger.debug("发现相邻环面(O型圈槽)")
                    
                    elif surf_type == GeomAbs_Cone:
                        # 锥面可能是倒角
                        other_props = GProp_GProps()
                        brepgprop.SurfaceProperties(other_face, other_props)
                        other_center = other_props.CentreOfMass()
                        
                        distance = face_center.Distance(other_center)
                        if distance < 5.0:  # 5mm以内
                            has_cone = True
                            logger.debug("发现相邻锥面(倒角)")
                
                face_explorer.Next()
            
            return has_torus, has_cone
            
        except Exception as e:
            logger.warning(f"检查相邻面时出错: {str(e)}")
            return False, False
    
    @staticmethod
    def _verify_context(face, shape, edge_data: Dict) -> float:
        """
        验证上下文合理性
        
        Returns:
            float: 上下文得分 (0-1)
        """
        score = 0.0
        
        try:
            # 1. 圆柱体的长径比应该合理
            surf = BRepAdaptor_Surface(face)
            cylinder = surf.Cylinder()
            diameter = cylinder.Radius() * 2
            z_range = edge_data.get('z_range', 0)
            
            if z_range > 0:
                length_to_diameter = z_range / diameter
                
                # 螺纹的长径比通常在0.5到20之间
                if 0.5 <= length_to_diameter <= 20:
                    score += 0.4
                elif 0.3 <= length_to_diameter <= 30:
                    score += 0.2
            
            # 2. 直径应该在合理范围内
            if 1.0 <= diameter <= 100.0:  # 1mm到100mm
                score += 0.3
            elif 0.5 <= diameter <= 150.0:
                score += 0.15
            
            # 3. 表面积应该合理
            surface_area = edge_data.get('surface_area', 0)
            if surface_area > 0:
                # 理论圆柱面积
                theoretical_area = math.pi * diameter * z_range
                if theoretical_area > 0:
                    area_ratio = surface_area / theoretical_area
                    # 实际面积应该接近理论值
                    if 0.8 <= area_ratio <= 1.2:
                        score += 0.3
                    elif 0.6 <= area_ratio <= 1.5:
                        score += 0.15
            
            return min(score, 1.0)
            
        except Exception as e:
            logger.warning(f"上下文验证时出错: {str(e)}")
            return 0.0
    
    @staticmethod
    def _verify_geometry(edge_data: Dict) -> float:
        """
        验证几何特征
        
        Returns:
            float: 几何得分 (0-1)
        """
        score = 0.0
        
        try:
            # 1. B样条数量
            bspline_count = edge_data.get('bspline_count', 0)
            edge_count = max(edge_data.get('edge_count', 1), 1)
            
            bspline_ratio = bspline_count / edge_count
            
            if bspline_ratio >= 0.8:  # 80%以上
                score += 0.3
            elif bspline_ratio >= 0.6:
                score += 0.2
            elif bspline_ratio >= 0.4:
                score += 0.1
            
            # 2. Z方向变化
            z_range = edge_data.get('z_range', 0)
            if z_range > 1.0:
                score += 0.2
            elif z_range > 0.5:
                score += 0.1
            
            # 3. 边缘密度
            edge_density = edge_data.get('edge_density', 0)
            if edge_density > 5:
                score += 0.2
            elif edge_density > 3:
                score += 0.15
            elif edge_density > 2:
                score += 0.1
            
            # 4. 螺旋连续性
            helix_continuity = edge_data.get('helical_continuity', 0)
            if helix_continuity > 0.8:
                score += 0.3
            elif helix_continuity > 0.6:
                score += 0.2
            elif helix_continuity > 0.4:
                score += 0.1
            
            return min(score, 1.0)
            
        except Exception as e:
            logger.warning(f"几何验证时出错: {str(e)}")
            return 0.0
    
    @staticmethod
    def _verify_thread_specific_features(edge_data: Dict) -> float:
        """
        验证螺纹特定特征
        
        Returns:
            float: 螺纹特征得分 (0-1)
        """
        score = 0.0
        
        try:
            # 1. 螺旋角度
            helix_angle = edge_data.get('helix_angle', 0)
            if 1.0 <= helix_angle <= 45.0:
                score += 0.25
            elif 0.5 <= helix_angle <= 60.0:
                score += 0.15
            
            # 2. 周期性
            z_period = edge_data.get('z_period', 0)
            z_period_confidence = edge_data.get('z_period_confidence', 0)
            
            if 0.1 < z_period < 10.0 and z_period_confidence > 0.6:
                score += 0.30
            elif 0.05 < z_period < 15.0 and z_period_confidence > 0.4:
                score += 0.15
            
            # 3. 螺距一致性
            pitch_consistency = edge_data.get('pitch_consistency', 0)
            if pitch_consistency > 0.7:
                score += 0.25
            elif pitch_consistency > 0.5:
                score += 0.15
            
            # 4. 控制点密度
            bspline_count = edge_data.get('bspline_count', 0)
            control_points_total = edge_data.get('control_points_total', 0)
            
            if bspline_count > 0:
                avg_control_points = control_points_total / bspline_count
                if avg_control_points >= 30:
                    score += 0.20
                elif avg_control_points >= 20:
                    score += 0.10
            
            return min(score, 1.0)
            
        except Exception as e:
            logger.warning(f"螺纹特征验证时出错: {str(e)}")
            return 0.0
    
    @staticmethod
    def _verify_standard_match(face, edge_data: Dict) -> float:
        """
        验证是否匹配标准螺纹
        
        Returns:
            float: 标准匹配得分 (0-1)
        """
        score = 0.0
        
        try:
            from thread_standards import ThreadStandards
            
            # 获取直径
            surf = BRepAdaptor_Surface(face)
            cylinder = surf.Cylinder()
            diameter = cylinder.Radius() * 2
            
            # 检查是否匹配标准螺纹直径
            standard_info = ThreadStandards.get_thread_standard_info(diameter)
            
            if standard_info and standard_info.get('standard_diameter', 0) > 0:
                std_diameter = standard_info['standard_diameter']
                diameter_diff = abs(diameter - std_diameter)
                
                # 直径匹配度
                if diameter_diff < 0.1:  # 0.1mm以内
                    score += 0.5
                elif diameter_diff < 0.3:  # 0.3mm以内
                    score += 0.3
                elif diameter_diff < 0.5:
                    score += 0.1
                
                # 螺距匹配度
                z_period = edge_data.get('z_period', 0)
                std_pitch = standard_info.get('pitch', 0)
                
                if z_period > 0 and std_pitch > 0:
                    pitch_diff = abs(z_period - std_pitch)
                    
                    if pitch_diff < 0.05:  # 0.05mm以内
                        score += 0.5
                    elif pitch_diff < 0.1:
                        score += 0.3
                    elif pitch_diff < 0.2:
                        score += 0.1
            
            return min(score, 1.0)
            
        except Exception as e:
            logger.warning(f"标准匹配验证时出错: {str(e)}")
            return 0.0