#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
BRepAdaptor_Surface B样条曲面参数提取器
从BRepAdaptor_Surface对象提取B样条曲面的完整参数信息并输出为JSON格式
"""

import json
from typing import Dict, List, Any, Optional

import numpy as np
from OCC.Core.BRepAdaptor import BRepAdaptor_Surface
from OCC.Core.Geom import Geom_BSplineSurface
from OCC.Core.GeomAbs import GeomAbs_BSplineSurface
from OCC.Core.gp import gp_Pnt, gp_Vec

class BSplineSurfaceParameterExtractor:
    """Extract the parameters of a B-spline surface from a BRepAdaptor_Surface.

    Every method is a ``@staticmethod``; the class only groups them. The
    structures returned are built from dicts, lists, bools, ints and floats
    so they can be passed directly to ``json.dump``.
    """

    @staticmethod
    def extract_parameters(adaptor_surface: BRepAdaptor_Surface) -> Optional[Dict[str, Any]]:
        """Extract B-spline surface parameters from a BRepAdaptor_Surface.

        Args:
            adaptor_surface: Adaptor wrapping the surface to inspect.

        Returns:
            A dict describing degrees, knots, control points, weights,
            parameter range, properties, sizes and sampled geometric
            features; or ``None`` if the adaptor's type is not
            ``GeomAbs_BSplineSurface``.
        """
        if adaptor_surface.GetType() != GeomAbs_BSplineSurface:
            return None

        bspline_surface = adaptor_surface.BSpline()
        extractor = BSplineSurfaceParameterExtractor

        return {
            "surface_type": "BSplineSurface",

            # Polynomial degree in each parametric direction.
            "degree": {
                "u_degree": bspline_surface.UDegree(),
                "v_degree": bspline_surface.VDegree(),
            },

            # Knot values and multiplicities; the lists are parallel (same index).
            "knots": {
                "u_knots": extractor._extract_knots(bspline_surface, "U"),
                "v_knots": extractor._extract_knots(bspline_surface, "V"),
                "u_multiplicities": extractor._extract_multiplicities(bspline_surface, "U"),
                "v_multiplicities": extractor._extract_multiplicities(bspline_surface, "V"),
            },

            # Control-point grid: outer list over U poles, inner list over V poles.
            "control_points": extractor._extract_control_points(bspline_surface),

            # None unless the surface is rational in U or V.
            "weights": extractor._extract_weights(bspline_surface),

            "parameter_range": {
                # NOTE: the *_knot_range entries are knot indices, not parameter values.
                "u_knot_range": [bspline_surface.FirstUKnotIndex(), bspline_surface.LastUKnotIndex()],
                "v_knot_range": [bspline_surface.FirstVKnotIndex(), bspline_surface.LastVKnotIndex()],
                # Parameter bounds as reported by the adaptor.
                "u_parameter_bounds": [adaptor_surface.FirstUParameter(), adaptor_surface.LastUParameter()],
                "v_parameter_bounds": [adaptor_surface.FirstVParameter(), adaptor_surface.LastVParameter()],
            },

            "properties": {
                "is_u_periodic": bspline_surface.IsUPeriodic(),
                "is_v_periodic": bspline_surface.IsVPeriodic(),
                "is_u_rational": bspline_surface.IsURational(),
                "is_v_rational": bspline_surface.IsVRational(),
                "is_u_closed": bspline_surface.IsUClosed(),
                "is_v_closed": bspline_surface.IsVClosed(),
                "continuity": str(bspline_surface.Continuity()),
            },

            "dimensions": {
                "num_u_poles": bspline_surface.NbUPoles(),
                "num_v_poles": bspline_surface.NbVPoles(),
                "num_u_knots": bspline_surface.NbUKnots(),
                "num_v_knots": bspline_surface.NbVKnots(),
            },

            "geometric_features": extractor._extract_geometric_features(
                bspline_surface, adaptor_surface
            ),
        }

    @staticmethod
    def _is_u_direction(direction: str) -> bool:
        """Return True for 'U'/'u' and False for 'V'/'v'.

        Raises:
            ValueError: For any other direction string. Previously, any value
                other than 'U' was silently treated as V.
        """
        normalized = direction.upper()
        if normalized not in ("U", "V"):
            raise ValueError(f"direction must be 'U' or 'V', got {direction!r}")
        return normalized == "U"

    @staticmethod
    def _extract_knots(surface: Geom_BSplineSurface, direction: str) -> List[float]:
        """Return the knot values in ``direction`` ('U' or 'V').

        One value is returned per knot index (OCCT indices are 1-based).
        Multiplicities are reported separately by ``_extract_multiplicities``.
        """
        if BSplineSurfaceParameterExtractor._is_u_direction(direction):
            return [surface.UKnot(i) for i in range(1, surface.NbUKnots() + 1)]
        return [surface.VKnot(i) for i in range(1, surface.NbVKnots() + 1)]

    @staticmethod
    def _extract_multiplicities(surface: Geom_BSplineSurface, direction: str) -> List[int]:
        """Return the multiplicity of each knot in ``direction`` ('U' or 'V').

        The result is parallel to the list returned by ``_extract_knots``.
        """
        if BSplineSurfaceParameterExtractor._is_u_direction(direction):
            return [surface.UMultiplicity(i) for i in range(1, surface.NbUKnots() + 1)]
        return [surface.VMultiplicity(i) for i in range(1, surface.NbVKnots() + 1)]

    @staticmethod
    def _point_to_dict(point: gp_Pnt) -> Dict[str, float]:
        """Convert a gp_Pnt to a JSON-friendly ``{"x", "y", "z"}`` dict."""
        return {"x": point.X(), "y": point.Y(), "z": point.Z()}

    @staticmethod
    def _evaluate(surface: Geom_BSplineSurface, u: float, v: float) -> gp_Pnt:
        """Evaluate the surface point at parameters (u, v)."""
        point = gp_Pnt()
        surface.D0(u, v, point)
        return point

    @staticmethod
    def _extract_control_points(surface: Geom_BSplineSurface) -> List[List[Dict[str, float]]]:
        """Return the pole grid as ``rows[u_index][v_index] -> {x, y, z}``."""
        to_dict = BSplineSurfaceParameterExtractor._point_to_dict
        return [
            [to_dict(surface.Pole(i, j)) for j in range(1, surface.NbVPoles() + 1)]
            for i in range(1, surface.NbUPoles() + 1)
        ]

    @staticmethod
    def _extract_weights(surface: Geom_BSplineSurface) -> Optional[List[List[float]]]:
        """Return the weight grid (same layout as the poles).

        Returns None for a non-rational surface, i.e. one that is rational
        in neither U nor V.
        """
        if not (surface.IsURational() or surface.IsVRational()):
            return None
        return [
            [surface.Weight(i, j) for j in range(1, surface.NbVPoles() + 1)]
            for i in range(1, surface.NbUPoles() + 1)
        ]

    @staticmethod
    def _extract_geometric_features(surface: Geom_BSplineSurface,
                                   adaptor: BRepAdaptor_Surface) -> Dict[str, Any]:
        """Collect corner points, centre point, approximate bbox and curvature samples.

        All evaluations use the adaptor's parameter bounds.
        """
        extractor = BSplineSurfaceParameterExtractor
        u_min, u_max = adaptor.FirstUParameter(), adaptor.LastUParameter()
        v_min, v_max = adaptor.FirstVParameter(), adaptor.LastVParameter()

        # Corners in order (u_min,v_min), (u_min,v_max), (u_max,v_min), (u_max,v_max).
        corners = [
            {
                "u": u,
                "v": v,
                "point": extractor._point_to_dict(extractor._evaluate(surface, u, v)),
            }
            for u in (u_min, u_max)
            for v in (v_min, v_max)
        ]

        # Midpoint of the parameter domain. This is not necessarily the
        # geometric centre of the surface.
        u_mid = (u_min + u_max) / 2
        v_mid = (v_min + v_max) / 2
        center_point = extractor._evaluate(surface, u_mid, v_mid)

        return {
            "corner_points": corners,
            "center_point": {
                "u": u_mid,
                "v": v_mid,
                "point": extractor._point_to_dict(center_point),
            },
            "bounding_box": extractor._calculate_bounding_box(surface, adaptor),
            "curvature_info": extractor._extract_curvature_info(surface, adaptor),
        }

    @staticmethod
    def _calculate_bounding_box(surface: Geom_BSplineSurface,
                               adaptor: BRepAdaptor_Surface,
                               num_samples: int = 20) -> Dict[str, float]:
        """Estimate an axis-aligned bounding box by sampling a uniform UV grid.

        The box is approximate: only ``num_samples x num_samples`` points are
        evaluated, so extrema lying between samples may be missed.

        Args:
            surface: Surface to evaluate.
            adaptor: Supplies the U/V parameter bounds to sample.
            num_samples: Samples per parametric direction (default 20).

        Raises:
            ValueError: If ``num_samples`` is less than 2.
        """
        if num_samples < 2:
            raise ValueError(f"num_samples must be >= 2, got {num_samples}")

        u_min, u_max = adaptor.FirstUParameter(), adaptor.LastUParameter()
        v_min, v_max = adaptor.FirstVParameter(), adaptor.LastVParameter()

        evaluate = BSplineSurfaceParameterExtractor._evaluate
        points = [
            evaluate(surface, float(u), float(v))
            for u in np.linspace(u_min, u_max, num_samples)
            for v in np.linspace(v_min, v_max, num_samples)
        ]
        coords = np.array([(p.X(), p.Y(), p.Z()) for p in points])
        lower = coords.min(axis=0)
        upper = coords.max(axis=0)

        # Cast to built-in float so the dict holds plain Python numbers.
        return {
            "x_min": float(lower[0]),
            "x_max": float(upper[0]),
            "y_min": float(lower[1]),
            "y_max": float(upper[1]),
            "z_min": float(lower[2]),
            "z_max": float(upper[2]),
            "diagonal_length": float(np.linalg.norm(upper - lower)),
        }

    @staticmethod
    def _curvature_at(surface: Geom_BSplineSurface, u: float, v: float) -> Optional[Dict[str, Any]]:
        """Compute normal, derivative magnitudes and curvatures at (u, v).

        Returns None when d1u x d1v is (near) zero, i.e. the point is singular.
        The normal is the parametric normal d1u x d1v. It does not account for
        face orientation, so the sign of the mean curvature follows the
        parametrisation.
        """
        p = gp_Pnt()
        d1u, d1v, d2u, d2v, d2uv = (gp_Vec() for _ in range(5))
        surface.D2(u, v, p, d1u, d1v, d2u, d2v, d2uv)

        normal = d1u.Crossed(d1v)
        if normal.Magnitude() <= 1e-10:
            return None

        # |Su x Sv|^2 == E*G - F^2 (Lagrange identity): the determinant of
        # the first fundamental form. Take it before normalising.
        det_first = normal.SquareMagnitude()
        normal.Normalize()

        # First fundamental form coefficients (E, F, G).
        g11 = d1u.Dot(d1u)
        g12 = d1u.Dot(d1v)
        g22 = d1v.Dot(d1v)
        # Second fundamental form coefficients (L, M, N).
        b11 = d2u.Dot(normal)
        b12 = d2uv.Dot(normal)
        b22 = d2v.Dot(normal)

        gaussian = (b11 * b22 - b12 * b12) / det_first
        mean = (g11 * b22 - 2.0 * g12 * b12 + g22 * b11) / (2.0 * det_first)

        return {
            "u": u,
            "v": v,
            "normal": {"x": normal.X(), "y": normal.Y(), "z": normal.Z()},
            "d1u_magnitude": d1u.Magnitude(),
            "d1v_magnitude": d1v.Magnitude(),
            "gaussian_curvature": gaussian,
            "mean_curvature": mean,
        }

    @staticmethod
    def _extract_curvature_info(surface: Geom_BSplineSurface,
                               adaptor: BRepAdaptor_Surface) -> Dict[str, Any]:
        """Sample differential-geometry data at the domain corners and midpoint.

        Points are skipped when evaluation raises or the parametric normal is
        degenerate. This is best-effort, so ``num_samples`` may be less than 5.
        """
        u_min, u_max = adaptor.FirstUParameter(), adaptor.LastUParameter()
        v_min, v_max = adaptor.FirstVParameter(), adaptor.LastVParameter()

        sample_points = [
            (u_min, v_min),
            (u_min, v_max),
            (u_max, v_min),
            (u_max, v_max),
            ((u_min + u_max) / 2, (v_min + v_max) / 2),
        ]

        curvature_data = []
        for u, v in sample_points:
            try:
                sample = BSplineSurfaceParameterExtractor._curvature_at(surface, u, v)
            except Exception:
                # Best-effort: skip points where OCCT evaluation fails. This no
                # longer uses a bare except, so KeyboardInterrupt/SystemExit
                # propagate.
                continue
            if sample is not None:
                curvature_data.append(sample)

        return {
            "sample_points": curvature_data,
            "num_samples": len(curvature_data),
        }

    @staticmethod
    def save_to_json(parameters: Dict[str, Any], output_path: str) -> None:
        """Write ``parameters`` to ``output_path`` as UTF-8, indented JSON.

        Non-ASCII characters are written as-is (``ensure_ascii=False``).
        A confirmation message is printed afterwards.

        Args:
            parameters: Parameter dict, e.g. from ``extract_parameters``.
            output_path: Destination file path; it is overwritten.
        """
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(parameters, f, indent=2, ensure_ascii=False)

        print(f"成功保存B样条曲面参数到: {output_path}")

# 使用示例函数
def extract_bspline_from_adaptor_surface(adaptor_surface: BRepAdaptor_Surface) -> Optional[Dict[str, Any]]:
    """Convenience wrapper for ``BSplineSurfaceParameterExtractor.extract_parameters``.

    Args:
        adaptor_surface: The BRepAdaptor_Surface to inspect.

    Returns:
        The B-spline parameter dict, or None when the adaptor's surface
        is not a B-spline surface.
    """
    extract = BSplineSurfaceParameterExtractor.extract_parameters
    return extract(adaptor_surface)

def process_multiple_surfaces(adaptor_surfaces: List[BRepAdaptor_Surface], 
                            output_file: str = "bspline_surfaces.json") -> None:
    """Extract every B-spline surface in ``adaptor_surfaces`` and write one JSON file.

    Non-B-spline surfaces are skipped. Each kept entry gets a
    ``surface_index`` key holding its position in the input list. The file
    records the total input count, the B-spline count and the entries, and
    a two-line summary is printed.

    Args:
        adaptor_surfaces: Adaptors to process, in order.
        output_file: Destination JSON path; it is overwritten.
    """
    extractor = BSplineSurfaceParameterExtractor
    # Lazy pipeline: each surface is extracted as the loop below reaches it.
    candidates = (
        (position, extractor.extract_parameters(surface))
        for position, surface in enumerate(adaptor_surfaces)
    )

    kept = []
    for position, entry in candidates:
        if not entry:
            continue
        entry["surface_index"] = position
        kept.append(entry)

    total = len(adaptor_surfaces)
    payload = {
        "total_surfaces": total,
        "bspline_surfaces": len(kept),
        "surfaces": kept,
    }

    with open(output_file, 'w', encoding='utf-8') as handle:
        json.dump(payload, handle, indent=2, ensure_ascii=False)

    print(f"处理了 {total} 个曲面，其中 {len(kept)} 个是B样条曲面")
    print(f"结果保存到: {output_file}")

# 使用示例
if __name__ == "__main__":
    # Usage sketch. It assumes a face object named `face` is already
    # available, e.g. from a loaded shape:
    #
    #     adaptor_surface = BRepAdaptor_Surface(face)
    #     params = extract_bspline_from_adaptor_surface(adaptor_surface)
    #     if params is None:
    #         print("该曲面不是B样条曲面")
    #     else:
    #         BSplineSurfaceParameterExtractor.save_to_json(params, "bspline_params.json")
    #
    # Running this module directly performs no work.
    pass