#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File    : funcs.py
@Ver     : 1.0
@Desc    : Layer-fusion (conv+BN[+ReLU]) utilities for spconv-based 3D backbones,
           checkpoint loading helpers, and a simple tensor file format shared with C++.
@Author  : Claude,refined by YangTianxi
@Time    : 2025/08/09 05:20:55
@Dev Software: Vscode
'''
import torch
import collections
import numpy as np
import cumm.tensorview as tv
import logging

try:
    # OpenPCDet imports
    from pcdet.models.backbones_3d import VoxelBackBone8x, VoxelResBackBone8x
    from spconv.pytorch import SparseSequential
    from spconv.pytorch import conv
    OPENPCDET_AVAILABLE = True
except ImportError:
    # Fallback to original imports
    try:
        from det3d.models.backbones.scn import SpMiddleResNetFHD, SparseBasicBlock
        from spconv.pytorch import SparseSequential
        from spconv.pytorch import conv
    except ImportError:
        print("Warning: Neither OpenPCDet nor det3d imports are available")
        # NOTE(review): the flag is only set on the failure paths; if the det3d
        # fallback import succeeds, OPENPCDET_AVAILABLE stays undefined — confirm
        # whether it should also be set to False in that branch.
        OPENPCDET_AVAILABLE = False


def make_new_repr(old_repr):
    """Wrap an existing __repr__ so that a module's fused activation
    (its `act_type` attribute, when present and not None) is appended
    inside the closing parenthesis of the original repr string."""
    def new_repr(self):
        s = old_repr(self)
        if hasattr(self, 'act_type') and self.act_type is not None:
            # Splice ", act=..." just before the last ")" of the repr.
            p = s.rfind(")")
            s = s[:p] + f', act={self.act_type}' + s[p:]
        return s
    return new_repr

# setup repr function, add activation
# Only patch when the spconv `conv` module actually imported above.
if 'conv' in locals():
    conv.SparseConvolution.__repr__ = make_new_repr(conv.SparseConvolution.__repr__)


def fuse_bn_weights(conv_w_OKI, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    """Fold BatchNorm statistics into convolution weights and bias.

    Args:
        conv_w_OKI: conv weight laid out as (Out, K..., In) — the spconv layout;
            the function permutes to (Out, In, K...) for the channel-wise scale,
            then permutes back.
        conv_b: conv bias or None (treated as zeros).
        bn_rm, bn_rv: BN running mean / running variance.
        bn_eps: BN epsilon.
        bn_w, bn_b: BN affine weight / bias, or None (treated as ones / zeros).

    Returns:
        (fused_weight, fused_bias) as torch.nn.Parameter, weight back in the
        original (Out, K..., In) layout.
    """
    NDim = conv_w_OKI.ndim - 2
    # (O, K..., I) -> (O, I, K...) so the scale broadcasts over dim 0.
    permute = [0, NDim+1] + [i+1 for i in range(NDim)]
    conv_w_OIK = conv_w_OKI.permute(*permute)
    # Missing bias / affine params default to identity values.
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    # w' = w * gamma / sqrt(var + eps);  b' = (b - mean) / sqrt(var + eps) * gamma + beta
    conv_w_OIK = conv_w_OIK * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w_OIK.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
    # (O, I, K...) -> back to (O, K..., I).
    permute = [0,] + [i+2 for i in range(NDim)] + [1,]
    conv_w_OKI = conv_w_OIK.permute(*permute).contiguous()
    return torch.nn.Parameter(conv_w_OKI), torch.nn.Parameter(conv_b)


def fuse_bn(conv, bn):
    """
    Fuse batch normalization into convolution layer

    Mutates `conv` in place: its weight and bias are replaced with the
    BN-folded values. Both modules must be in eval mode.
    """
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    conv.weight, conv.bias = fuse_bn_weights(
        conv.weight, conv.bias,
        bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias
    )


def fuse_sparse_sequential_layers(sequential_module, module_name=""):
    """
    Fuse every conv+bn(+relu) combination inside a SparseSequential-like
    container, in place.

    Args:
        sequential_module: a container exposing __len__/children()/clear()/append().
        module_name: dotted name used only for log messages.

    Returns:
        Number of fused conv+bn pairs.
    """
    fused_count = 0

    # Not a Sequential-like container (or empty): nothing to do.
    if not hasattr(sequential_module, '__len__') or len(sequential_module) == 0:
        return fused_count

    # Work on a plain list of the children.
    modules = list(sequential_module.children())
    new_modules = []
    i = 0

    while i < len(modules):
        current_module = modules[i]

        # Check for a conv followed by a BatchNorm.
        if (i < len(modules) - 1 and
                hasattr(current_module, 'weight') and  # make sure it is a conv layer
                ('Conv' in type(current_module).__name__ or hasattr(current_module, 'kernel_size')) and
                'BatchNorm' in type(modules[i + 1]).__name__):
            try:
                # Fold the BN into the conv.
                fuse_bn(current_module, modules[i + 1])

                # Check for a trailing ReLU / activation module.
                if (i < len(modules) - 2 and
                        ('ReLU' in type(modules[i + 2]).__name__ or
                         'Activation' in type(modules[i + 2]).__name__)):
                    current_module.act_type = tv.gemm.Activation.ReLU
                    i += 3  # skip bn and relu
                    logging.debug(f"Fused {module_name}: conv+bn+relu at index {i-3}")
                else:
                    # conv+bn only.
                    # NOTE(review): ReLU is still attached here even though no
                    # activation module follows — confirm this is intentional,
                    # otherwise it changes the block's output.
                    current_module.act_type = tv.gemm.Activation.ReLU
                    i += 2  # skip bn
                    logging.debug(f"Fused {module_name}: conv+bn at index {i-2}")

                new_modules.append(current_module)
                fused_count += 1

            except Exception as e:
                # Keep both modules untouched when fusion fails.
                logging.warning(f"Failed to fuse at {module_name}[{i}]: {e}")
                new_modules.extend([current_module, modules[i + 1]])
                i += 2
        else:
            new_modules.append(current_module)
            i += 1

    # Rebuild the Sequential if any fusion happened.
    if fused_count > 0:
        # Empty the original container, then append the surviving modules.
        sequential_module.clear()
        for module in new_modules:
            sequential_module.append(module)
        logging.info(f"Rebuilt {module_name} with {len(new_modules)} modules (fused {fused_count} layers)")

    return fused_count


def recursive_fusion(module, module_name="", depth=0):
    """
    Recursively walk all submodules and fuse conv+bn(+relu) patterns.

    Args:
        module: root module to process (mutated in place).
        module_name: dotted path of `module`, used for logging.
        depth: recursion depth, used only for log indentation.

    Returns:
        Total number of fused pairs in this subtree.
    """
    total_fused = 0

    # Sequential-like container (has __len__ and append): fuse it directly.
    if hasattr(module, '__len__') and hasattr(module, 'append'):
        fused = fuse_sparse_sequential_layers(module, module_name)
        total_fused += fused
        if fused > 0:
            logging.info(f"{' ' * depth}Fused {fused} layers in {module_name}")

    # Recurse into every child module.
    for name, child in module.named_children():
        child_name = f"{module_name}.{name}" if module_name else name
        child_fused = recursive_fusion(child, child_name, depth + 1)
        total_fused += child_fused

    return total_fused


def layer_fusion_comprehensive(model, verbose: bool = True):
    """
    Comprehensive layer fusion: recursively handles every container
    structure, then runs a direct conv+bn sweep as a catch-all.

    Args:
        model: the model to fuse (mutated in place; switched to eval mode).
        verbose: DEBUG-level logging when True, WARNING-level otherwise.

    Returns:
        The same (mutated) model.

    Raises:
        Re-raises whatever the fusion passes raise.
    """
    if verbose:
        logging.basicConfig(level=logging.DEBUG, force=True)
    else:
        logging.basicConfig(level=logging.WARNING, force=True)

    model_type = type(model).__name__
    logging.info(f"Starting comprehensive layer fusion for model type: {model_type}")

    # fuse_bn asserts eval mode, so enforce it up front.
    model.eval()

    try:
        # Pass 1: recursive fusion over all containers.
        total_fused = recursive_fusion(model, model_type)

        # Pass 2: direct lookup of any remaining conv+bn pairs.
        additional_fused = direct_conv_bn_fusion(model)
        total_fused += additional_fused

        logging.info(f"Comprehensive fusion completed: {total_fused} total fusions")

    except Exception as e:
        logging.error(f"Layer fusion failed: {e}")
        raise

    return model


def direct_conv_bn_fusion(model):
    """
    Directly search named_modules() for conv layers immediately followed
    (within the next two entries, under the same parent) by a BatchNorm,
    fuse each pair, and replace the BN with torch.nn.Identity.

    Returns:
        Number of pairs fused.
    """
    fused_count = 0
    modules_list = list(model.named_modules())

    for i, (name, module) in enumerate(modules_list):
        # Find convolution layers.
        if (hasattr(module, 'weight') and
                hasattr(module, 'kernel_size') and
                ('Conv' in type(module).__name__ or 'SparseConv' in type(module).__name__)):

            # Look for the matching BN layer among the following modules.
            conv_name_parts = name.split('.')

            # Inspect adjacent siblings only.
            for j in range(i + 1, min(i + 3, len(modules_list))):  # 只检查接下来的2个模块
                next_name, next_module = modules_list[j]

                if 'BatchNorm' in type(next_module).__name__:
                    # Is this the BN paired with the conv?
                    next_name_parts = next_name.split('.')

                    # Same parent module => adjacent layers.
                    if (len(conv_name_parts) == len(next_name_parts) and
                            conv_name_parts[:-1] == next_name_parts[:-1]):
                        try:
                            # Attempt the fusion.
                            fuse_bn(module, next_module)
                            # NOTE(review): ReLU is attached unconditionally —
                            # confirm every fused conv here is really followed
                            # by a ReLU in the network.
                            module.act_type = tv.gemm.Activation.ReLU

                            # Remove the BN layer from its parent module.
                            parent_module = model
                            for part in next_name_parts[:-1]:
                                parent_module = getattr(parent_module, part)

                            if hasattr(parent_module, next_name_parts[-1]):
                                # Replace the BN layer with Identity (effectively deleting it).
                                setattr(parent_module, next_name_parts[-1], torch.nn.Identity())

                            fused_count += 1
                            logging.info(f"Direct fusion: {name} + {next_name}")

                        except Exception as e:
                            logging.warning(f"Failed direct fusion {name} + {next_name}: {e}")

                    break  # stop searching once a BN layer was found

    return fused_count


def layer_fusion(model, verbose: bool = True):
    """
    Unified entry point for layer fusion (delegates to
    layer_fusion_comprehensive).
    """
    return layer_fusion_comprehensive(model, verbose)


def validate_fusion(model):
    """
    Validate the fusion result: check whether any unfused BatchNorm layers
    remain in the model.

    Returns:
        True when no unfused BatchNorm layer is left, else False.
    """
    unfused_bn_count = 0
    unfused_bn_names = []

    for name, module in model.named_modules():
        if 'BatchNorm' in type(module).__name__:
            # Skip Identity modules (BN layers that were already replaced).
            if not isinstance(module, torch.nn.Identity):
                unfused_bn_count += 1
                unfused_bn_names.append(name)
                logging.warning(f"Unfused BatchNorm found: {name}")

    if unfused_bn_count == 0:
        logging.info("✓ All BatchNorm layers successfully fused")
    else:
        logging.warning(f"⚠ {unfused_bn_count} BatchNorm layers remain unfused")
        logging.info("Unfused BN layers: " + ", ".join(unfused_bn_names))

    return unfused_bn_count == 0

#'''
def load_scn_checkpoint(model, file, model_type="VoxelBackBone8x"):
    """
    Load OpenPCDet checkpoint for various backbone types

    Strips common backbone prefixes from the checkpoint keys, keeps only
    conv/bn/downsample tensors, and loads them non-strictly; on failure it
    falls back to loading shape-compatible tensors one by one.

    Returns:
        The (mutated) model.
    """
    device = next(model.parameters()).device
    ckpt = torch.load(file, map_location=device)

    # Handle different checkpoint formats
    if "model_state" in ckpt:
        state_dict = ckpt["model_state"]
    elif "state_dict" in ckpt:
        state_dict = ckpt["state_dict"]
    else:
        state_dict = ckpt

    new_ckpt = collections.OrderedDict()

    # Extract backbone weights based on model type
    backbone_prefixes = ["backbone_3d.", "module.backbone_3d.", "backbone."]

    for key, val in state_dict.items():
        new_key = key

        # Remove common prefixes
        for prefix in backbone_prefixes:
            if key.startswith(prefix):
                new_key = key[len(prefix):]
                break

        # Handle weight dimension adjustments if needed
        if val.ndim == 5:
            # For 3D convolutions, might need permutation based on framework differences
            # Keep original order for OpenPCDet compatibility
            val = val.permute(0, 1, 2, 3, 4)  # Usually no change needed for OpenPCDet

        # Only include backbone parameters
        if any(layer_name in new_key for layer_name in ['conv', 'bn', 'downsample']):
            new_ckpt[new_key] = val

    try:
        model.load_state_dict(new_ckpt, strict=False)
        total_params = sum(p.numel() for p in new_ckpt.values())
        print(f"Successfully loaded {len(new_ckpt)} tensors from checkpoint, total parameters: {total_params}")
    except Exception as e:
        print(f"Warning: Could not load some parameters: {e}")
        # Try to load parameters one by one
        model_dict = model.state_dict()
        loaded_keys = []
        for key, val in new_ckpt.items():
            if key in model_dict and model_dict[key].shape == val.shape:
                model_dict[key] = val
                loaded_keys.append(key)
        model.load_state_dict(model_dict)
        print(f"Loaded {len(loaded_keys)} parameters individually")

    return model
#'''

def load_scn_checkpointV2(model, file, model_type="VoxelBackBone8x"):
    """
    Load backbone weights only if parameter count matches.

    Args:
        model (torch.nn.Module): The target model (e.g., VoxelBackBone8x()).
        file (str): Path to the checkpoint file.
        model_type (str): Name/prefix of the backbone module in checkpoint (e.g., "backbone_3d").
    """
    device = next(model.parameters()).device
    ckpt = torch.load(file, map_location=device)

    # Try different key names for checkpoint
    if "model_state" in ckpt:
        state_dict = ckpt["model_state"]
    elif "state_dict" in ckpt:
        state_dict = ckpt["state_dict"]
    else:
        state_dict = ckpt

    # 1. Count the trainable parameters in the model.
    model_param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"✅ Current model has {model_param_count:,} trainable parameters.")

    # 2. Extract the parameters related to model_type from the checkpoint.
    # Try several possible module-name prefixes.
    possible_prefixes = [
        f"{model_type}.",
        f"module.{model_type}.",
        f"{model_type}_3d.",
        "backbone_3d.",
        "module.backbone_3d.",
        "voxel_backbone.",
        "VoxelBackBone8x.",
    ]

    new_ckpt = collections.OrderedDict()
    matched_prefix = None

    # Find the first prefix that matches any checkpoint key.
    for prefix in possible_prefixes:
        temp_ckpt = collections.OrderedDict()
        for key, val in state_dict.items():
            if key.startswith(prefix):
                new_key = key[len(prefix):]
                temp_ckpt[new_key] = val
        if temp_ckpt:  # found parameters under this prefix
            new_ckpt = temp_ckpt
            matched_prefix = prefix
            break

    if not new_ckpt:
        print(f"❌ No parameters found for any of these prefixes: {possible_prefixes}")
        return model

    print(f"📦 Found parameters with prefix: '{matched_prefix}'")

    # 3. Count the extracted trainable parameters (exclude BatchNorm running stats etc.).
    ckpt_trainable_count = 0
    ckpt_total_count = 0
    non_trainable_suffixes = ['running_mean', 'running_var', 'num_batches_tracked']

    for key, val in new_ckpt.items():
        param_count = val.numel()
        ckpt_total_count += param_count

        # Is this a non-trainable buffer?
        is_non_trainable = any(suffix in key for suffix in non_trainable_suffixes)
        if not is_non_trainable:
            ckpt_trainable_count += param_count

    print(f"📦 Checkpoint contains {ckpt_total_count:,} total parameters for '{model_type}'.")
    print(f"📦 Checkpoint contains {ckpt_trainable_count:,} trainable parameters for '{model_type}'.")

    # 4. Skip loading unless the trainable parameter counts agree.
    if ckpt_trainable_count != model_param_count:
        print(f"❌ Trainable parameter count mismatch:")
        print(f"   Model: {model_param_count:,}")
        print(f"   Checkpoint: {ckpt_trainable_count:,}")
        print(f"   Skip loading.")
        return model

    print(f"✅ Trainable parameter counts match: {model_param_count:,}")

    # 5. Load the parameters and verify loading completeness.
    try:
        # Non-strict load to obtain detailed missing/unexpected key info.
        missing_keys, unexpected_keys = model.load_state_dict(new_ckpt, strict=False)

        # Count the parameters that were actually loadable.
        model_dict = model.state_dict()
        loaded_param_count = 0
        loaded_keys_count = 0

        for key, val in new_ckpt.items():
            if key in model_dict and model_dict[key].shape == val.shape:
                # Exclude non-trainable buffers from the count.
                is_non_trainable = any(suffix in key for suffix in non_trainable_suffixes)
                if not is_non_trainable:
                    loaded_param_count += val.numel()
                loaded_keys_count += 1

        if missing_keys:
            print(f"⚠️ Missing keys in checkpoint: {len(missing_keys)} keys")
            if len(missing_keys) <= 10:  # show details when only a few keys are missing
                for key in missing_keys:
                    print(f" - {key}")
            else:
                print(f" First 10: {missing_keys[:10]}")

        if unexpected_keys:
            print(f"⚠️ Unexpected keys in checkpoint: {len(unexpected_keys)} keys")
            if len(unexpected_keys) <= 10:
                for key in unexpected_keys:
                    print(f" - {key}")
            else:
                print(f" First 10: {unexpected_keys[:10]}")

        if not missing_keys and not unexpected_keys:
            print(f"✅ Successfully loaded all parameters from checkpoint.")
            print(f"🎯 Loaded {loaded_param_count:,} trainable parameters ({loaded_keys_count} parameter tensors).")

            # Verify that the loaded count equals the model's trainable count.
            if loaded_param_count == model_param_count:
                print(f"✅ Parameter count verification passed: {loaded_param_count:,}")
            else:
                print(f"⚠️ Parameter count mismatch after loading:")
                print(f"   Expected: {model_param_count:,}")
                print(f"   Actually loaded: {loaded_param_count:,}")
        else:
            print(f"✅ Successfully loaded checkpoint with some missing/unexpected keys.")
            print(f"🎯 Loaded {loaded_param_count:,} trainable parameters ({loaded_keys_count} parameter tensors).")

            # Verify the loaded parameter count.
            if loaded_param_count == model_param_count:
                print(f"✅ Parameter count verification passed despite missing/unexpected keys.")
            else:
                print(f"❌ Parameter count verification failed:")
                print(f"   Expected: {model_param_count:,}")
                print(f"   Actually loaded: {loaded_param_count:,}")
                print(f"   Missing: {model_param_count - loaded_param_count:,} parameters")

                # On mismatch, raise (caught below and handled by the fallback path).
                raise ValueError(f"Incomplete parameter loading: expected {model_param_count}, got {loaded_param_count}")

    except Exception as e:
        print(f"⚠️ Error during loading: {e}")
        print(f"🔧 Attempting individual parameter loading...")

        # Fallback: load compatible parameters one by one.
        model_dict = model.state_dict()
        loaded_keys = []
        shape_mismatch_keys = []
        loaded_param_count = 0

        for key, val in new_ckpt.items():
            if key in model_dict:
                if model_dict[key].shape == val.shape:
                    model_dict[key] = val
                    loaded_keys.append(key)
                    # Count trainable parameters only.
                    is_non_trainable = any(suffix in key for suffix in non_trainable_suffixes)
                    if not is_non_trainable:
                        loaded_param_count += val.numel()
                else:
                    shape_mismatch_keys.append((key, model_dict[key].shape, val.shape))

        model.load_state_dict(model_dict)

        print(f"🔧 Successfully loaded {len(loaded_keys)}/{len(new_ckpt)} parameter tensors.")
        print(f"🎯 Loaded {loaded_param_count:,} trainable parameters.")

        # Verify the fallback-loaded parameter count.
        if loaded_param_count == model_param_count:
            print(f"✅ Parameter count verification passed: {loaded_param_count:,}")
        else:
            print(f"❌ Parameter count verification failed:")
            print(f"   Expected: {model_param_count:,}")
            print(f"   Actually loaded: {loaded_param_count:,}")
            print(f"   Missing: {model_param_count - loaded_param_count:,} parameters")

        if shape_mismatch_keys:
            print(f"⚠️ Shape mismatches for {len(shape_mismatch_keys)} parameters:")
            for key, model_shape, ckpt_shape in shape_mismatch_keys[:5]:  # show only the first 5
                print(f"   {key}: model{model_shape} vs checkpoint{ckpt_shape}")

    return model

'''
def save_tensor(tensor, filename):
    """
    Save tensor to file for verification
    """
    if isinstance(tensor, torch.Tensor):
        data = tensor.detach().cpu().numpy()
    else:
        data = np.array(tensor)

    with open(filename, 'wb') as f:
        np.save(f, data)
    print(f"Saved tensor with shape {data.shape} to (unknown)")
'''

# This function stores a file that can be very easily loaded and used by c++
# Format: int32 header [magic, ndim, dtype_code], then int32 dims, then raw data.
def save_tensor(tensor, file):
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.detach().cpu().data.numpy()
    elif not isinstance(tensor, np.ndarray):
        tensor = np.array(tensor)

    # dtype -> integer code written into the header (must match load_tensor).
    dtype_map = {"float32" : 0, "float16" : 1, "int32" : 2, "int64" : 3}
    if str(tensor.dtype) not in dtype_map:
        raise RuntimeError(f"Unsupport dtype {tensor.dtype}")

    magic_number = 0x33ff1101
    with open(file, "wb") as f:
        head = np.array([magic_number, tensor.ndim, dtype_map[str(tensor.dtype)]], dtype=np.int32).tobytes()
        f.write(head)
        dims = np.array(tensor.shape, dtype=np.int32).tobytes()
        f.write(dims)
        data = tensor.tobytes()
        f.write(data)

# This function stores a file that can be very easily loaded and used by c++
# (counterpart of save_tensor above: reads the same header + dims + raw layout).
def load_tensor(file):
    # integer code -> numpy dtype, and dtype -> element size in bytes.
    dtype_for_integer_mapping = {0: np.float32, 1: np.float16, 2: np.int32, 3: np.int64}
    dtype_size_mapping = {np.float32 : 4, np.float16 : 2, np.int32 : 4, np.int64 : 8}

    with open(file, "rb") as f:
        magic_number, ndim, dtype_integer = np.frombuffer(f.read(12), dtype=np.int32)
        if dtype_integer not in dtype_for_integer_mapping:
            raise RuntimeError(f"Can not find match dtype for index {dtype_integer}")

        dtype = dtype_for_integer_mapping[dtype_integer]
        magic_number_std = 0x33ff1101
        assert magic_number == magic_number_std, f"this file is not tensor file"
        dims = np.frombuffer(f.read(ndim * 4), dtype=np.int32)
        # total element count = product of dims.
        volumn = np.cumprod(dims)[-1]
        data = np.frombuffer(f.read(volumn * dtype_size_mapping[dtype]), dtype=dtype).reshape(*dims)
    return data


def print_model_structure(model, max_depth=3):
    """
    Print the model structure (for debugging), up to `max_depth` levels;
    leaf modules also show their fused activation (`act_type`) when set.
    """
    def print_module(module, name="", depth=0):
        if depth > max_depth:
            return

        indent = " " * depth
        module_type = type(module).__name__

        if hasattr(module, '__len__') and len(module) > 0:
            # Sequential-like container: show child count, recurse by index.
            print(f"{indent}{name}: {module_type}({len(module)} modules)")
            for i, child in enumerate(module.children()):
                print_module(child, f"[{i}]", depth + 1)
        elif list(module.children()):
            # Regular module with named children.
            print(f"{indent}{name}: {module_type}")
            for child_name, child in module.named_children():
                print_module(child, child_name, depth + 1)
        else:
            # Leaf module: annotate with fused activation when present.
            extra_info = ""
            if hasattr(module, 'act_type') and module.act_type is not None:
                extra_info = f" (act={module.act_type})"
            print(f"{indent}{name}: {module_type}{extra_info}")

    print("Model Structure:")
    print_module(model, type(model).__name__)