#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
PRT特征识别API客户端测试脚本
用于测试API服务的各种功能
"""

import requests
import time
import json
from pathlib import Path
import argparse


class PRTFeatureAPIClient:
    """HTTP client for the PRT feature-recognition API.

    Every public request method follows the same convention. On success it
    returns the decoded JSON body. On any failure (transport error, non-2xx
    status, undecodable body, unreadable local file) it prints a diagnostic
    and returns ``None``.

    The client owns a ``requests.Session``. Call :meth:`close`, or use the
    client as a context manager, to release its pooled connections.
    """

    def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 60):
        """Create a client for the service at ``base_url``.

        Args:
            base_url: Root URL of the service. Trailing slashes are stripped
                so endpoint paths can be appended with a leading ``/``.
            timeout: Seconds passed as ``timeout=`` to every request.
                ``requests`` has no default timeout, so without this a
                stalled server would block the caller indefinitely.
        """
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout
        self.session = requests.Session()

    def close(self):
        """Close the underlying session and its pooled connections."""
        self.session.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def _send(self, method: str, path: str, **kwargs):
        """Send ``method`` to ``base_url + path`` and return the response.

        The client-wide timeout is applied unless the caller passes its own.
        4xx/5xx statuses raise via ``raise_for_status``. Transport errors
        propagate unchanged.
        """
        kwargs.setdefault('timeout', self.timeout)
        response = self.session.request(method, f"{self.base_url}{path}", **kwargs)
        response.raise_for_status()
        return response

    def _fetch_json(self, error_label: str, method: str, path: str, **kwargs):
        """Send a request and return its JSON body, or ``None`` on failure.

        Any exception (transport, HTTP status, JSON decoding) is printed as
        ``"<error_label>: <exception>"`` and swallowed, because this script
        treats every API call as best-effort.
        """
        try:
            return self._send(method, path, **kwargs).json()
        except Exception as e:
            print(f"{error_label}: {str(e)}")
            return None

    def check_health(self):
        """Query ``GET /health`` and return the health payload."""
        return self._fetch_json("健康检查失败", 'GET', '/health')

    def upload_file(self, file_path: str,
                    enable_topology: bool = True,
                    enable_process_planning: bool = False):
        """Upload a local file to start a feature-extraction job.

        Args:
            file_path: Path of the file to upload. It is sent as the
                multipart field ``file`` under its base name.
            enable_topology: Sent as query parameter
                ``enable_topology_relations``.
            enable_process_planning: Sent as query parameter
                ``enable_process_planning``.

        Returns:
            The JSON response body, or ``None`` if the file cannot be read
            or the request fails.
        """
        params = {
            'enable_topology_relations': enable_topology,
            'enable_process_planning': enable_process_planning,
        }
        try:
            # The file handle must stay open until the request body is sent.
            with open(file_path, 'rb') as f:
                files = {'file': (Path(file_path).name, f, 'application/octet-stream')}
                response = self._send('POST', '/api/v1/extract/upload',
                                      files=files, params=params)
            return response.json()
        except Exception as e:
            print(f"文件上传失败: {str(e)}")
            return None

    def get_job_status(self, job_id: str):
        """Return the status payload of job ``job_id``."""
        return self._fetch_json("获取任务状态失败", 'GET', f'/api/v1/jobs/{job_id}')

    def get_job_result(self, job_id: str, output_path: str = None):
        """Fetch the result of job ``job_id``.

        Args:
            job_id: Identifier of the job.
            output_path: If given, the raw response bytes are written to this
                path before the body is decoded.

        Returns:
            The decoded JSON result, or ``None`` on failure.
        """
        try:
            response = self._send('GET', f'/api/v1/jobs/{job_id}/result')
            if output_path:
                with open(output_path, 'wb') as f:
                    f.write(response.content)
                print(f"结果已保存到: {output_path}")
            return response.json()
        except Exception as e:
            print(f"获取任务结果失败: {str(e)}")
            return None

    def list_jobs(self):
        """Return the job listing from ``GET /api/v1/jobs``."""
        return self._fetch_json("列出任务失败", 'GET', '/api/v1/jobs')

    def delete_job(self, job_id: str):
        """Delete job ``job_id`` and return the server's response body."""
        return self._fetch_json("删除任务失败", 'DELETE', f'/api/v1/jobs/{job_id}')

    def cleanup_old_jobs(self, max_age_hours: int = 24):
        """Request cleanup of old jobs.

        Args:
            max_age_hours: Sent as query parameter ``max_age_hours`` to
                ``POST /api/v1/cleanup``.
        """
        return self._fetch_json("清理任务失败", 'POST', '/api/v1/cleanup',
                                params={'max_age_hours': max_age_hours})

    def wait_for_completion(self, job_id: str, max_wait: int = 300, interval: int = 2):
        """Poll job ``job_id`` until it completes, fails, or ``max_wait`` elapses.

        Args:
            job_id: Job to poll.
            max_wait: Overall time budget in seconds.
            interval: Delay between polls in seconds.

        Returns:
            The last status dict when its ``status`` is ``'completed'`` or
            ``'failed'``. ``None`` if a status query fails or the time
            budget runs out.
        """
        deadline = time.time() + max_wait

        while time.time() < deadline:
            status = self.get_job_status(job_id)
            if not status:
                return None

            state = status.get('status')
            # 'message' is not guaranteed to be present; a missing key must
            # not abort polling with a KeyError.
            print(f"任务状态: {state} - {status.get('message', '')}")

            if state == 'completed':
                return status
            if state == 'failed':
                print(f"任务失败: {status.get('error', '未知错误')}")
                return status

            # Never sleep past the deadline.
            time.sleep(max(0, min(interval, deadline - time.time())))

        print(f"等待超时 ({max_wait}秒)")
        return None


def test_workflow(client: PRTFeatureAPIClient, file_path: str, output_dir: str = "./results"):
    """Run the end-to-end smoke test against a live service.

    Steps:
        1. Health check.
        2. Upload ``file_path``.
        3. Poll the job until it finishes (up to 300 s).
        4. Save the JSON result as ``<output_dir>/<job_id>_result.json``.
        5. Print the job count reported by the server.

    Progress is printed to stdout. The function returns early if the health
    check, the upload, or the job itself fails.

    Args:
        client: Client bound to the service under test.
        file_path: Local PRT/STEP file to upload.
        output_dir: Directory for the result file. Created if missing.
    """
    banner = "=" * 60
    print(banner)
    print("开始测试PRT特征识别API")
    print(banner)

    # 1. Health check. Payload fields are read with .get() so a response
    #    missing one of them is still reported instead of raising KeyError.
    print("\n1. 健康检查...")
    health = client.check_health()
    if not health:
        print("   健康检查失败!")
        return
    print(f"   服务状态: {health.get('status')}")
    print(f"   版本: {health.get('version')}")
    print(f"   主板验证: {health.get('motherboard_uid_verified')}")

    # 2. Upload the file. The response must carry a job id to poll.
    print(f"\n2. 上传文件: {file_path}")
    job_info = client.upload_file(file_path, enable_topology=True)
    if not job_info:
        print("   文件上传失败!")
        return

    job_id = job_info.get('job_id')
    if not job_id:
        print(f"   上传响应缺少任务ID: {job_info}")
        return
    print(f"   任务ID: {job_id}")
    print(f"   初始状态: {job_info.get('status')}")

    # 3. Poll until the job reaches a terminal state (at most 300 s).
    print("\n3. 等待处理完成...")
    final_status = client.wait_for_completion(job_id, max_wait=300)
    if not final_status:
        print("   任务未能完成!")
        return
    if final_status.get('status') != 'completed':
        print(f"   任务失败: {final_status.get('error', '未知错误')}")
        return

    # 4. Download the result JSON into output_dir.
    print("\n4. 获取处理结果...")
    output_path = Path(output_dir) / f"{job_id}_result.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)

    result = client.get_job_result(job_id, str(output_path))
    if result:
        print(f"   ✓ 结果已保存到: {output_path}")
        print(f"   特征数量: {len(result.get('advanced_features', {}))}")
    else:
        print("   获取结果失败!")

    # 5. Report how many jobs the server currently lists.
    print("\n5. 列出所有任务...")
    jobs = client.list_jobs()
    if jobs:
        print(f"   当前任务总数: {jobs.get('total')}")

    print("\n" + banner)
    print("测试完成!")
    print(banner)


def main():
    """Command-line entry point: parse arguments and run the chosen action.

    ``--action`` selects what to do:

    * ``test`` (default): run :func:`test_workflow` on ``--file``.
    * ``health``: print the health-check response as JSON.
    * ``list``: print the job list as JSON.
    * ``cleanup``: call ``cleanup_old_jobs`` with ``--max-age-hours``
      (default 24) and print the response as JSON.

    A failed call returns ``None``, which is printed as ``null``.
    """
    parser = argparse.ArgumentParser(description="PRT特征识别API客户端测试")
    parser.add_argument(
        '--url',
        type=str,
        default='http://localhost:8000',
        help='API服务地址'
    )
    parser.add_argument(
        '--file',
        type=str,
        help='要上传的PRT/STEP文件路径'
    )
    parser.add_argument(
        '--output',
        type=str,
        default='./results',
        help='结果输出目录'
    )
    parser.add_argument(
        '--action',
        type=str,
        choices=['test', 'health', 'list', 'cleanup'],
        default='test',
        help='执行的操作'
    )
    parser.add_argument(
        '--max-age-hours',
        type=int,
        default=24,
        help='cleanup操作: 清理超过该时长(小时)的任务'
    )

    args = parser.parse_args()

    client = PRTFeatureAPIClient(args.url)
    try:
        if args.action == 'health':
            health = client.check_health()
            print(json.dumps(health, indent=2, ensure_ascii=False))

        elif args.action == 'list':
            jobs = client.list_jobs()
            print(json.dumps(jobs, indent=2, ensure_ascii=False))

        elif args.action == 'cleanup':
            result = client.cleanup_old_jobs(args.max_age_hours)
            print(json.dumps(result, indent=2, ensure_ascii=False))

        elif args.action == 'test':
            if not args.file:
                print("错误: 测试模式需要指定文件路径 (--file)")
                return

            if not Path(args.file).exists():
                print(f"错误: 文件不存在: {args.file}")
                return

            test_workflow(client, args.file, args.output)
    finally:
        # Release the session's pooled HTTP connections on every exit path.
        client.session.close()


if __name__ == "__main__":
    main()
