#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
特征输出模块，用于将识别的特征和拓扑关系写入到文件
兼容评分系统的版本
"""

import os
import json
import logging
import csv
from datetime import datetime
from typing import List, Dict, Any, Optional

# Module-level logger; FeatureWriter reports saved files and skipped/odd
# feature objects through it (info/warning levels).
logger = logging.getLogger(__name__)

class FeatureWriter:
    """Write recognized features and topology relations to output files.

    All files are written under ``output_dir``:

    * ``<filename>_特征和拓扑.json`` -- features, relations, patterns, statistics
    * ``<feature type>_特征.csv``    -- one CSV file per feature type
    * ``<filename>_评分报告.json``    -- a raw score report
    * ``<filename>_拓扑图.json``      -- node/link data for graph visualization
    """

    # Characters replaced by '_' when a feature type becomes part of a file
    # name: path separators (POSIX and Windows) and NUL. Without this, a type
    # such as "孔/槽" would escape output_dir or fail to open.
    _UNSAFE_FILENAME_CHARS = frozenset('/\\\0')

    def __init__(self, output_dir: str):
        """Create a writer that places every output file in ``output_dir``.

        Args:
            output_dir: Target directory. It is created (including parents)
                if missing. An empty string means the current directory.

        Raises:
            FileExistsError: if ``output_dir`` exists but is not a directory.
        """
        self.output_dir = output_dir

        # exist_ok=True removes the exists()/makedirs() race; '' already
        # denotes the current directory and os.makedirs('') would raise.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    def write_features(self, features: List[Any], filename: str,
                       topology_relations: Optional[List[Dict]] = None,
                       feature_patterns: Optional[List[Dict]] = None,
                       score_info: Optional[Dict] = None):
        """Write features and topology relations (currently JSON only).

        Args:
            features: Feature objects (dicts, objects, or type-name strings).
            filename: Base name for the output file.
            topology_relations: Relation dicts between features.
            feature_patterns: Feature pattern dicts.
            score_info: Optional scoring information stored in the metadata.
        """
        self.write_json_with_topology(
            features, filename, topology_relations,
            feature_patterns, score_info=score_info
        )

        # CSV output is optional and currently disabled:
        # self.write_csv(features)

        logger.info(f"共将 {len(features)} 个特征写入到 {self.output_dir} 目录")
        if topology_relations:
            logger.info(f"共将 {len(topology_relations)} 个拓扑关系写入到文件")

    def write_json_with_topology(self, features: List[Any], filename: str,
                                 topology_relations: Optional[List[Dict]] = None,
                                 feature_patterns: Optional[List[Dict]] = None,
                                 topology_graph: Optional[Dict] = None,
                                 score_info: Optional[Dict] = None):
        """Write features, relations and summary statistics to one JSON file.

        The document has top-level keys ``metadata`` (name, timestamp,
        counts, optional ``scores``), ``features``, ``topology`` (optional
        ``relations``/``patterns``/``graph``) and ``statistics``. Optional
        sections are only included when the corresponding input is truthy.

        Args:
            features: Feature objects; unrecognized ones are skipped.
            filename: Base name; output is ``<filename>_特征和拓扑.json``.
            topology_relations: Relation dicts (need ``feature1_index`` and
                ``feature2_index``).
            feature_patterns: Pattern dicts, stored as-is.
            topology_graph: Optional graph dict, stored as-is.
            score_info: Optional scoring info stored under ``metadata.scores``.

        Raises:
            TypeError: if any value is not JSON-serializable (no file is
                written in that case).
        """
        output_data: Dict[str, Any] = {
            'metadata': {
                'filename': filename,
                'timestamp': datetime.now().isoformat(),
                'counts': {
                    'features': len(features),
                    'relations': len(topology_relations) if topology_relations else 0,
                    'patterns': len(feature_patterns) if feature_patterns else 0
                }
            },
            'features': [],
            'topology': {}
        }

        if score_info:
            output_data['metadata']['scores'] = score_info

        # Count feature types while collecting (no separate per-type grouping).
        feature_types: Dict[Any, int] = {}
        for i, feature in enumerate(features):
            feature_data = self._extract_feature_data(feature, i)
            if feature_data is None:
                continue
            output_data['features'].append(feature_data)
            ft = feature_data['type']
            feature_types[ft] = feature_types.get(ft, 0) + 1

        if topology_relations:
            output_data['topology']['relations'] = [
                self._extract_relation_data(r) for r in topology_relations
            ]

        if feature_patterns:
            output_data['topology']['patterns'] = feature_patterns

        if topology_graph:
            output_data['topology']['graph'] = topology_graph

        output_data['statistics'] = {
            'feature_types': feature_types,
            'relation_summary': self._summarize_relations(topology_relations)
        }

        output_file = os.path.join(self.output_dir, f'{filename}_特征和拓扑.json')
        self._write_json(output_file, output_data)

        logger.info(f"特征和拓扑数据已保存到 {output_file}")

    @staticmethod
    def _write_json(output_file: str, data: Any) -> None:
        """Write ``data`` to ``output_file`` as indented, non-ASCII-escaped UTF-8 JSON.

        The document is fully serialized *before* the file is opened, so a
        non-serializable value raises ``TypeError`` without leaving a
        truncated file behind (or clobbering a previous good one).
        """
        text = json.dumps(data, indent=2, ensure_ascii=False)
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(text)

    @staticmethod
    def _split_relation_types(rel_type: Any) -> List[Any]:
        """Split one ``relation_type`` value into individual relation types.

        Accepts a comma-separated string ("相邻, 平行"), a list/tuple of
        types, or a single value. Blank segments (e.g. from a trailing comma)
        are dropped; ``None`` or an all-blank split yields ``['未知']``.
        """
        if rel_type is None:
            return ['未知']
        if isinstance(rel_type, (list, tuple)):
            parts = [str(t).strip() for t in rel_type]
        elif isinstance(rel_type, str) and ',' in rel_type:
            parts = [t.strip() for t in rel_type.split(',')]
        else:
            # A single relation type is counted verbatim.
            return [rel_type]
        return [p for p in parts if p] or ['未知']

    def _summarize_relations(self, topology_relations: Optional[List[Dict]]) -> Dict:
        """Build a compact summary of the relations.

        Returns:
            ``{}`` for no relations, otherwise a dict with ``total``,
            ``multi_relation_count`` (relations flagged ``is_multi``) and
            ``by_type`` (per-type counts, with multi-relations split into
            their individual types). A missing ``relation_type`` counts as
            '未知'.
        """
        if not topology_relations:
            return {}

        multi_count = 0
        by_type: Dict[Any, int] = {}
        for relation in topology_relations:
            if relation.get('is_multi'):
                multi_count += 1
            for rel_type in self._split_relation_types(relation.get('relation_type', '未知')):
                by_type[rel_type] = by_type.get(rel_type, 0) + 1

        return {
            'total': len(topology_relations),
            'multi_relation_count': multi_count,
            'by_type': by_type
        }

    def _extract_feature_data(self, feature: Any, index: int) -> Optional[Dict[str, Any]]:
        """Convert one feature into a JSON-ready dict.

        Supported inputs:
            * ``str``: treated as a bare type name (a warning is logged).
            * ``dict``: reads ``type``, ``properties`` and ``score_info``.
            * objects (with ``__dict__`` or ``__slots__``): reads the first
              present of several type/property attribute names (Chinese and
              English), plus ``score_info``/``评分信息`` and ``step_info``.

        Non-dict ``properties`` are ignored, on both the dict and the object
        path, so ``properties`` is always a dict.

        Args:
            feature: The feature to convert.
            index: Position of the feature; becomes its ``id``.

        Returns:
            The feature dict, or ``None`` for unsupported inputs.
        """
        feature_data: Dict[str, Any] = {
            'id': index,
            'type': '未知',
            'properties': {}
        }

        if isinstance(feature, str):
            feature_data['type'] = feature
            logger.warning(f"特征 {index} 是字符串类型: {feature}")

        elif isinstance(feature, dict):
            feature_data['type'] = feature.get('type', '未知')
            props = feature.get('properties')
            if isinstance(props, dict):
                feature_data['properties'] = props
            if 'score_info' in feature:
                feature_data['score_info'] = feature['score_info']

        elif hasattr(feature, '__dict__') or hasattr(type(feature), '__slots__'):
            # The first attribute name that exists wins, even if its value is
            # unusable (mirrors the historical lookup order).
            for attr in ('特征类型', 'feature_type', 'type', '类型'):
                if hasattr(feature, attr):
                    feature_data['type'] = getattr(feature, attr)
                    break

            for attr in ('属性', 'properties', 'props'):
                if hasattr(feature, attr):
                    props = getattr(feature, attr)
                    if isinstance(props, dict):
                        feature_data['properties'] = props
                    break

            if hasattr(feature, 'score_info'):
                feature_data['score_info'] = feature.score_info
            elif hasattr(feature, '评分信息'):
                feature_data['score_info'] = getattr(feature, '评分信息')

            if hasattr(feature, 'step_info'):
                feature_data['step_info'] = feature.step_info
        else:
            logger.warning(f"未知的特征类型: {type(feature)}")
            return None

        return feature_data

    def _extract_relation_data(self, relation: Dict) -> Dict[str, Any]:
        """Convert one relation dict into its compact output form.

        ``feature1_index``/``feature2_index`` are required (``KeyError``
        otherwise). A missing ``relation_type`` becomes '未知', and
        ``details`` is only included when non-empty.
        """
        relation_data = {
            'feature1_id': relation['feature1_index'],
            'feature2_id': relation['feature2_index'],
            'relation': relation.get('relation_type', '未知'),
            'multi': relation.get('is_multi', False)
        }

        if relation.get('details'):
            relation_data['details'] = relation['details']

        return relation_data

    def _generate_statistics(self, features: List[Dict],
                             topology_relations: Optional[List[Dict]],
                             feature_patterns: Optional[List[Dict]]) -> Dict[str, Any]:
        """Compute detailed statistics over already-extracted feature dicts.

        Args:
            features: Feature dicts (as produced by ``_extract_feature_data``).
            topology_relations: Relation dicts; ``relation_type_list`` is
                preferred over ``relation_type`` when present.
            feature_patterns: Pattern dicts.

        Returns:
            Dict with per-type counts for features, relations and patterns,
            the number of multi-type relation pairs, and ``score_statistics``
            holding avg/min/max of '置信度' (empty if no confidences exist).
        """
        stats: Dict[str, Any] = {
            'feature_types': {},
            'relation_types': {},
            'pattern_types': {},
            'multi_relation_pairs': 0,
            'score_statistics': {}
        }

        confidences = []
        for feature in features:
            feature_type = feature.get('type', '未知')
            stats['feature_types'][feature_type] = stats['feature_types'].get(feature_type, 0) + 1

            # score_info may be absent, None or a non-dict; only dicts count.
            score_info = feature.get('score_info')
            if isinstance(score_info, dict) and '置信度' in score_info:
                confidences.append(score_info['置信度'])

        if confidences:
            score_stats = stats['score_statistics']
            score_stats['avg_confidence'] = sum(confidences) / len(confidences)
            score_stats['min_confidence'] = min(confidences)
            score_stats['max_confidence'] = max(confidences)

        relation_types = stats['relation_types']
        for relation in topology_relations or []:
            if 'relation_type_list' in relation:
                type_list = relation['relation_type_list']
                if len(type_list) > 1:
                    stats['multi_relation_pairs'] += 1
                for rel_type in type_list:
                    relation_types[rel_type] = relation_types.get(rel_type, 0) + 1
            else:
                rel_type = relation.get('relation_type', '未知')
                relation_types[rel_type] = relation_types.get(rel_type, 0) + 1

        pattern_types = stats['pattern_types']
        for pattern in feature_patterns or []:
            pattern_type = pattern.get('type', '未知')
            pattern_types[pattern_type] = pattern_types.get(pattern_type, 0) + 1

        return stats

    @staticmethod
    def _as_dict(value: Any) -> Dict:
        """Return ``value`` if it is a dict, otherwise an empty dict."""
        return value if isinstance(value, dict) else {}

    @classmethod
    def _safe_filename(cls, name: Any) -> str:
        """Return ``str(name)`` with path separators and NUL replaced by '_'."""
        return ''.join('_' if ch in cls._UNSAFE_FILENAME_CHARS else ch for ch in str(name))

    def write_csv(self, features: List[Any]):
        """Write one CSV file per feature type.

        Each file is ``<type>_特征.csv`` (the type is sanitized for use in a
        file name). The columns are the ID, the type, and the sorted union of
        all property and score-info keys. Property values take precedence
        over score-info values with the same key.

        Args:
            features: Feature objects; unrecognized ones are skipped.
        """
        features_by_type: Dict[Any, List[Dict[str, Any]]] = {}
        for i, feature in enumerate(features):
            feature_data = self._extract_feature_data(feature, i)
            if feature_data is not None:
                features_by_type.setdefault(feature_data['type'], []).append(feature_data)

        for feature_type, type_features in features_by_type.items():
            # Union of every column key across this type's features.
            column_keys = set()
            for feature_data in type_features:
                column_keys.update(self._as_dict(feature_data.get('properties')))
                column_keys.update(self._as_dict(feature_data.get('score_info')))

            # key=str keeps the plain order for string keys and avoids a
            # TypeError when keys of different types are mixed.
            property_names = sorted(column_keys, key=str)

            csv_file = os.path.join(
                self.output_dir, f'{self._safe_filename(feature_type)}_特征.csv'
            )
            with open(csv_file, 'w', newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
                writer.writerow(['特征ID', '特征类型'] + property_names)

                for feature_data in type_features:
                    properties = self._as_dict(feature_data.get('properties'))
                    score_info = self._as_dict(feature_data.get('score_info'))
                    row = [feature_data['id'], feature_type]
                    row.extend(properties.get(p, score_info.get(p, '')) for p in property_names)
                    writer.writerow(row)

            logger.info(f"{feature_type}特征数据已保存到 {csv_file}")

    def export_score_report(self, score_data: Dict, filename: str):
        """Write ``score_data`` unchanged to ``<filename>_评分报告.json``.

        Args:
            score_data: JSON-serializable score data.
            filename: Base name for the output file.
        """
        output_file = os.path.join(self.output_dir, f'{filename}_评分报告.json')
        self._write_json(output_file, score_data)
        logger.info(f"评分报告已保存到 {output_file}")

    def export_topology_graph_visualization(self, topology_graph: Dict, filename: str):
        """Write a node/link view of the topology graph for visualization.

        Nodes need an ``id``. Missing ``type``/``properties`` fall back to
        '未知'/``{}``. Edges default to source/target -1 and type '未知'.
        Nothing is written when ``topology_graph`` is empty.

        Args:
            topology_graph: Dict with optional ``nodes`` and ``edges`` lists.
            filename: Base name; output is ``<filename>_拓扑图.json``.
        """
        if not topology_graph:
            return

        nodes = [
            {
                'id': node['id'],
                'label': f"{node.get('type', '未知')}_{node['id']}",
                'type': node.get('type', '未知'),
                'properties': node.get('properties', {})
            }
            for node in topology_graph.get('nodes') or []
        ]

        links = [
            {
                'source': edge.get('source', -1),
                'target': edge.get('target', -1),
                'type': edge.get('type', '未知'),
                'label': edge.get('type', '未知'),
                'details': edge.get('details', {})
            }
            for edge in topology_graph.get('edges') or []
        ]

        output_file = os.path.join(self.output_dir, f'{filename}_拓扑图.json')
        self._write_json(output_file, {'nodes': nodes, 'links': links})

        logger.info(f"拓扑图可视化数据已保存到 {output_file}")