Source code for xuance.torch.utils.device

import os
import platform
import torch
import xuance


[docs] def set_device(expected_device: str): """ Set the computing device for a given deep learning framework. Args: dl_toolbox (str): The deep learning framework to use. Options: "torch", "tensorflow", "mindspore". expected_device (str): The desired computing device. Options: "cuda", "GPU", "gpu", "Ascend", "cpu", "CPU. Returns: str: The assigned computing device, which may differ from `expected_device` if the requested device is unavailable. """ device = expected_device if ("cuda" in expected_device) or (expected_device.upper() == "GPU"): if not torch.cuda.is_available(): print("WARNING: CUDA for PyTorch is not available, set the device as 'cpu'.") device = "cpu" elif expected_device.upper() == "GPU": print(f"WARNING: the device name {expected_device} is invalid, set the device as 'cuda:0'.") device = "cuda:0" elif expected_device.upper() == "CPU": device = "cpu" else: print(f"WARNING: the device name {expected_device} is invalid, set the device as 'cpu'.") device = "cpu" return device
[docs] def collect_device_info( rank: int = 0, agent=None, ) -> dict: """Collect runtime device / system info for reproducibility. Returns a JSON-serializable dict. """ info = { "Platform": platform.platform(), "CUDA_Available": bool(torch.cuda.is_available()), "Python": platform.python_version(), "XuanCe": xuance.__version__, "PyTorch": getattr(torch, "__version__", "unknown"), "PID": os.getpid(), "Rank": rank, } # Try to use agent's real device (most reliable). device = None try: # Find a parameter device if possible. if agent is not None: # Try common attribute names in your codebase obj = getattr(agent, "policy", None) if obj is not None and hasattr(obj, "parameters"): device = next(obj.parameters()).device # Fallback: config.device if no parameter found if device is None: device = torch.device(str(agent.config.device)) if device is None: device = torch.device("cpu") info["device"] = str(device) if device.type == "cuda": idx = device.index if device.index is not None else 0 info.update({ "gpu_index": int(idx), "gpu_name": torch.cuda.get_device_name(idx), "cuda_version": getattr(torch.version, "cuda", None), }) # Optional: driver / capability try: cap = torch.cuda.get_device_capability(idx) info["gpu_capability"] = f"{cap[0]}.{cap[1]}" except Exception: pass else: info.update({ "cpu_arch": platform.processor(), }) except Exception as e: # If torch isn't available or something fails, keep it minimal but valid. info["device"] = str(getattr(getattr(agent, 'config', None), "device", "unknown")) info["device_info_error"] = repr(e) return info