#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
改进的螺纹特征分析模块
降低对非螺纹特征的误判率
"""

import logging
import math
import numpy as np
from typing import Dict, Any, List, Optional
from OCC.Core.BRepAdaptor import BRepAdaptor_Surface, BRepAdaptor_Curve
from OCC.Core.GeomAbs import (
    GeomAbs_Cylinder, GeomAbs_Circle, GeomAbs_BSplineCurve, 
    GeomAbs_Torus, GeomAbs_Cone
)
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.TopAbs import TopAbs_EDGE, TopAbs_FACE
from OCC.Core.TopoDS import topods, TopoDS_Face, TopoDS_Shape
from OCC.Core.gp import gp_Pnt
from id2.thread_analysis_utils import ThreadAnalysisUtils
logger = logging.getLogger(__name__)
from id2.thread_verification import ThreadVerifier
class ThreadAnalyzer:
    """Heuristic analysis of whether a cylindrical face carries a thread.

    Public entry points:

    * :meth:`is_thread` -- yes/no decision, delegated to ``ThreadVerifier``.
    * :meth:`analyze_thread` -- descriptive properties and scores for a face.

    The private ``_check_*`` helpers return scores in ``[0, 1]``; they only
    read the dict produced by :meth:`_analyze_edges`.
    """

    # Presumably nominal thread diameters (mm?) -- TODO confirm units. Not
    # referenced inside this class; kept for external callers.
    STANDARD_THREAD_DIAMETERS = [
        0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2, 2.5,
        3.0, 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0, 22.0, 24.0
    ]

    # Number of evenly spaced parameter values sampled along each B-spline edge.
    _SAMPLES_PER_EDGE = 20

    @staticmethod
    def is_thread(cylindrical_face, shape):
        """Return True if ``cylindrical_face`` is judged to be a thread.

        Non-cylindrical faces are rejected immediately. Otherwise the face's
        edges are summarised with ``ThreadAnalysisUtils.analyze_edges`` and
        the decision is delegated to ``ThreadVerifier.verify_thread``.

        Args:
            cylindrical_face: the face to test.
            shape: the shape containing the face (passed through to the
                verifier for context checks).

        Returns:
            ``verification_result.is_thread``; ``False`` on any exception
            (the exception is logged with its traceback).
        """
        try:
            surf = BRepAdaptor_Surface(cylindrical_face)
            if surf.GetType() != GeomAbs_Cylinder:
                return False

            edge_data = ThreadAnalysisUtils.analyze_edges(cylindrical_face)
            verification_result = ThreadVerifier.verify_thread(
                cylindrical_face, shape, edge_data
            )

            if verification_result.is_thread:
                logger.info(f"✓ 螺纹验证通过: {verification_result.positive_indicators}")
            else:
                logger.info(f"✗ 螺纹验证失败: {verification_result.rejection_reasons}")

            return verification_result.is_thread

        except Exception as e:
            # Public API boundary: never propagate, but keep the traceback.
            logger.exception(f"螺纹判断时出错: {str(e)}")
            return False

    @staticmethod
    def _sample_edge_z(curve_adaptor, n_samples: int) -> List[float]:
        """Sample Z at ``n_samples`` evenly spaced parameters of a curve adaptor.

        Returns an empty list when the parameter range is not finite. Any
        exception from ``Value()`` propagates, so the caller never sees a
        partially sampled edge.
        """
        first = curve_adaptor.FirstParameter()
        last = curve_adaptor.LastParameter()
        if not (math.isfinite(first) and math.isfinite(last)):
            return []
        span = last - first
        denom = n_samples - 1
        return [curve_adaptor.Value(first + span * i / denom).Z() for i in range(n_samples)]

    @staticmethod
    def _analyze_edges(cylindrical_face: "TopoDS_Face") -> Dict[str, Any]:
        """Summarise the edges of a face.

        Returns a dict with:
            edge_count: number of edges visited by the explorer.
            bspline_count / circle_count: edges of those curve types.
            control_points_total: sum of B-spline pole counts (where readable).
            z_coordinates: Z samples, ``_SAMPLES_PER_EDGE`` per sampled
                B-spline edge, concatenated in explorer order.
            sampled_edge_count: number of B-spline edges whose samples are in
                ``z_coordinates`` (used to split the samples back per edge).
            z_range: ``max - min`` of ``z_coordinates`` (0 if none).
            edge_lengths: not populated by this method (kept for compatibility).
        """
        n_samples = ThreadAnalyzer._SAMPLES_PER_EDGE
        result = {
            'edge_count': 0,
            'bspline_count': 0,
            'circle_count': 0,
            'z_range': 0,
            'z_coordinates': [],
            'control_points_total': 0,
            'edge_lengths': [],
            'sampled_edge_count': 0,
        }
        z_coords: List[float] = []

        try:
            explorer = TopExp_Explorer(cylindrical_face, TopAbs_EDGE)
            while explorer.More():
                edge = topods.Edge(explorer.Current())
                result['edge_count'] += 1

                try:
                    curve_adaptor = BRepAdaptor_Curve(edge)
                    curve_type = curve_adaptor.GetType()

                    if curve_type == GeomAbs_BSplineCurve:
                        result['bspline_count'] += 1

                        try:
                            result['control_points_total'] += curve_adaptor.BSpline().NbPoles()
                        except Exception as e:
                            # Pole count is informative only; keep analysing.
                            logger.debug(f"读取B样条控制点失败: {e}")

                        # Sample the whole edge first, then commit, so a failure
                        # mid-edge cannot leave a partial group in z_coords.
                        samples = ThreadAnalyzer._sample_edge_z(curve_adaptor, n_samples)
                        if samples:
                            z_coords.extend(samples)
                            result['sampled_edge_count'] += 1

                    elif curve_type == GeomAbs_Circle:
                        result['circle_count'] += 1

                except Exception as e:
                    # Best effort: one unreadable edge should not abort the scan.
                    logger.debug(f"单条边分析失败: {e}")

                explorer.Next()

        except Exception as e:
            logger.error(f"边缘分析错误: {str(e)}")

        result['z_coordinates'] = z_coords
        if z_coords:
            result['z_range'] = max(z_coords) - min(z_coords)

        return result

    @staticmethod
    def _collect_edges(face) -> List:
        """Return the edges of ``face`` as a list (copied via ``topods.Edge``)."""
        edges = []
        explorer = TopExp_Explorer(face, TopAbs_EDGE)
        while explorer.More():
            edges.append(topods.Edge(explorer.Current()))
            explorer.Next()
        return edges

    @staticmethod
    def _shares_edge(face, edges: List) -> bool:
        """Return True if any edge of ``face`` ``IsSame`` one of ``edges``."""
        explorer = TopExp_Explorer(face, TopAbs_EDGE)
        while explorer.More():
            current = explorer.Current()
            if any(current.IsSame(own) for own in edges):
                return True
            explorer.Next()
        return False

    @staticmethod
    def _analyze_context(cylindrical_face: "TopoDS_Face", shape: "TopoDS_Shape") -> Dict:
        """Describe the faces of ``shape`` adjacent to ``cylindrical_face``.

        A face is treated as a neighbour when it shares at least one edge
        with ``cylindrical_face`` (edges compared with ``IsSame``). Each
        distinct neighbour face is counted once.

        Returns:
            has_torus_neighbors / has_cone_neighbors: whether any neighbour
                has that surface type.
            neighbor_count: number of distinct neighbouring faces.
        """
        context = {
            'has_torus_neighbors': False,
            'has_cone_neighbors': False,
            'neighbor_count': 0
        }

        try:
            own_edges = ThreadAnalyzer._collect_edges(cylindrical_face)
            neighbors = []

            face_explorer = TopExp_Explorer(shape, TopAbs_FACE)
            while face_explorer.More():
                face = topods.Face(face_explorer.Current())
                face_explorer.Next()

                if face.IsSame(cylindrical_face):
                    continue
                # The explorer may reach the same face more than once.
                if any(face.IsSame(seen) for seen in neighbors):
                    continue

                try:
                    if not ThreadAnalyzer._shares_edge(face, own_edges):
                        continue
                    neighbors.append(face)

                    surf_type = BRepAdaptor_Surface(face).GetType()
                    if surf_type == GeomAbs_Torus:
                        context['has_torus_neighbors'] = True
                    elif surf_type == GeomAbs_Cone:
                        context['has_cone_neighbors'] = True
                except Exception as e:
                    # Best effort: skip faces whose topology/geometry can't be read.
                    logger.debug(f"相邻面分析失败: {e}")

            context['neighbor_count'] = len(neighbors)

        except Exception as e:
            logger.error(f"上下文分析错误: {str(e)}")

        return context

    @staticmethod
    def _check_spiral_quality(cylindrical_face: "TopoDS_Face", edge_data: Dict) -> float:
        """Score how helix-like the sampled edges are, in ``[0, 1]``.

        Weighted sum of:
            0.4 * monotonicity -- share of consecutive Z steps that all go the
                  same direction (steps within +-0.001 count as neither);
            0.3 * min(z_range / 10, 1);
            0.3 * min(bspline_count / 6, 1).

        ``cylindrical_face`` is not used; it is kept for interface
        compatibility. Returns 0.0 when there are no B-spline edges, fewer
        than 10 Z samples, or the data is malformed.
        """
        if not edge_data.get('bspline_count'):
            return 0.0

        z_coords = edge_data.get('z_coordinates', [])
        if len(z_coords) < 10:
            return 0.0

        try:
            z_diffs = [b - a for a, b in zip(z_coords, z_coords[1:])]

            positive = sum(1 for d in z_diffs if d > 0.001)
            negative = sum(1 for d in z_diffs if d < -0.001)
            monotonicity = max(positive, negative) / len(z_diffs)

            z_range_score = min(edge_data['z_range'] / 10, 1.0)
            bspline_score = min(edge_data['bspline_count'] / 6, 1.0)

            return monotonicity * 0.4 + z_range_score * 0.3 + bspline_score * 0.3

        except Exception:
            return 0.0

    @staticmethod
    def _check_pitch_consistency(edge_data: Dict) -> float:
        """Score how evenly spaced consecutive sampled edges are, in ``[0, 1]``.

        The Z samples are split into one group per sampled edge. Here the
        "pitch" is the absolute difference between the mean Z of consecutive
        groups; gaps <= 0.01 are ignored. The score is
        ``1 - min(std / mean, 1)`` over those pitches.

        Groups use ``sampled_edge_count`` when present (it matches how the
        samples were produced). Otherwise they fall back to ``bspline_count``
        for dicts from other producers.

        Returns 0.0 when there are fewer than 20 samples, fewer than two
        groups, fewer than two usable pitches, or the data is malformed.
        """
        z_coords = edge_data.get('z_coordinates', [])
        if len(z_coords) < 20:
            return 0.0

        try:
            groups = edge_data.get('sampled_edge_count', edge_data.get('bspline_count', 0))
            if groups < 2:
                return 0.0

            points_per_group = len(z_coords) // groups
            if points_per_group == 0:
                # More groups than samples: slices would be empty (NaN means).
                return 0.0

            group_means = [
                np.mean(z_coords[i * points_per_group:(i + 1) * points_per_group])
                for i in range(groups)
            ]
            pitches = [
                abs(b - a)
                for a, b in zip(group_means, group_means[1:])
                if abs(b - a) > 0.01
            ]

            if len(pitches) < 2:
                return 0.0

            mean_pitch = np.mean(pitches)
            if mean_pitch <= 0:
                return 0.0

            return float(1.0 - min(np.std(pitches) / mean_pitch, 1.0))

        except Exception:
            return 0.0

    @staticmethod
    def analyze_thread(cylindrical_face: "TopoDS_Face", shape: "TopoDS_Shape") -> Dict[str, Any]:
        """Return descriptive thread properties for a cylindrical face.

        Keys (Chinese, as consumed downstream): type, diameter (2 * radius,
        rounded to 3 d.p.), B-spline count, Z range, spiral quality and pitch
        consistency scores. Non-cylindrical faces keep the zero defaults.
        If an exception occurs, its message is stored under ``'错误'``.
        ``shape`` is currently unused.
        """
        properties = {
            '类型': '螺纹',
            '直径': 0.0,
            'B样条数量': 0,
            'Z变化范围': 0.0,
            '螺旋质量': 0.0,
            '螺距一致性': 0.0
        }

        try:
            surf = BRepAdaptor_Surface(cylindrical_face)
            if surf.GetType() == GeomAbs_Cylinder:
                radius = surf.Cylinder().Radius()
                properties['直径'] = round(radius * 2, 3)

                edge_data = ThreadAnalyzer._analyze_edges(cylindrical_face)
                properties['B样条数量'] = edge_data['bspline_count']
                properties['Z变化范围'] = round(edge_data['z_range'], 3)
                properties['螺旋质量'] = round(
                    ThreadAnalyzer._check_spiral_quality(cylindrical_face, edge_data), 3
                )
                properties['螺距一致性'] = round(
                    ThreadAnalyzer._check_pitch_consistency(edge_data), 3
                )

        except Exception as e:
            properties['错误'] = str(e)

        return properties