Source code for xuance.common.statistic_tools

from xuance.common import Union, Sequence
import numpy as np


[docs] def mpi_mean(x, axis=0, comm=None, keepdims=False): """ Compute the mean across all MPI processes. Parameters: x (array-like): Input array. axis (int): Axis along which the mean is computed. comm (MPI.Comm, optional): MPI communicator. Defaults to MPI.COMM_WORLD. keepdims (bool): Whether to keep the dimensions of the result. Defaults to False. Returns: tuple: A tuple containing: - mean (ndarray): Mean of the input array. - count (int): Total count used for the mean computation. """ x = np.asarray(x) assert x.ndim > 0 if comm is None: from mpi4py import MPI comm = MPI.COMM_WORLD xsum = x.sum(axis=axis, keepdims=keepdims) n = xsum.size localsum = np.zeros(n + 1, x.dtype) localsum[:n] = xsum.ravel() localsum[n] = x.shape[axis] globalsum = np.zeros_like(localsum) comm.Allreduce(localsum, globalsum, op=MPI.SUM) return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n]
[docs] def mpi_moments(x, axis=0, comm=None, keepdims=False): """ Compute the mean and standard deviation across MPI processes. Parameters: x (array-like): Input array for which to calculate moments. axis (int, optional): Axis along which to compute the mean and standard deviation. Defaults to 0. comm (MPI.Comm, optional): MPI communicator for distributed computation. If None, defaults to MPI.COMM_WORLD. keepdims (bool, optional): Whether to retain reduced dimensions in the output. Defaults to False. Returns: tuple: A tuple containing: - mean (ndarray): Mean of the input array. - std (ndarray): Standard deviation of the input array. - count (int): Total count used for the computation. """ x = np.asarray(x) assert x.ndim > 0 mean, count = mpi_mean(x, axis=axis, comm=comm, keepdims=True) sqdiffs = np.square(x - mean) meansqdiff, count1 = mpi_mean(sqdiffs, axis=axis, comm=comm, keepdims=True) assert count1 == count std = np.sqrt(meansqdiff) if not keepdims: newshape = mean.shape[:axis] + mean.shape[axis + 1:] mean = mean.reshape(newshape) std = std.reshape(newshape) return mean, std, count
[docs] class RunningMeanStd(object): """ Maintains a running mean and standard deviation. Attributes: shape (Union[Sequence[int], dict]): Shape of the input data. epsilon (float): Small value to prevent division by zero. comm (MPI.Comm): MPI communicator for distributed computation. use_mpi (bool): Whether to use MPI for computation. """ def __init__(self, shape: Union[Sequence[int], dict], epsilon=1e-4, comm=None, use_mpi=False): """ Initialize the RunningMeanStd object. Parameters: shape (Union[Sequence[int], dict]): Shape of the input data. epsilon (float): Small value to prevent division by zero. comm (MPI.Comm, optional): MPI communicator. Defaults to MPI.COMM_WORLD. use_mpi (bool): Whether to use MPI for computation. Defaults to False. """ self.shape = shape if isinstance(shape, dict): self.mean = {key: np.zeros(shape[key], np.float32) for key in shape.keys()} self.var = {key: np.ones(shape[key], np.float32) for key in shape.keys()} self.count = {key: epsilon for key in shape.keys()} else: self.mean = np.zeros(shape, np.float32) self.var = np.ones(shape, np.float32) self.count = epsilon self.use_mpi = use_mpi # if comm is None: # from mpi4py import MPI # comm = MPI.COMM_WORLD self.comm = comm @property def std(self): """ Compute the standard deviation. Returns: Union[dict, ndarray]: The standard deviation of the running statistics. """ if isinstance(self.shape, dict): return {key: np.sqrt(self.var[key]) for key in self.shape.keys()} else: return np.sqrt(self.var)
[docs] def update(self, x): """ Update the running mean and standard deviation with new data. Parameters: x (Union[dict, ndarray]): New data to update the statistics. """ if isinstance(x, dict): batch_means = {} batch_vars = {} batch_counts = {} for key in self.shape.keys(): if self.use_mpi: batch_mean, batch_std, batch_count = mpi_moments(x[key], axis=0, comm=self.comm) else: batch_mean, batch_std, batch_count = np.mean(x[key], axis=0), np.std(x[key], axis=0), x[key].shape[ 0] batch_means[key] = batch_mean batch_vars[key] = np.square(batch_std) batch_counts[key] = batch_count self.update_from_moments(batch_means, batch_vars, batch_counts) else: if self.use_mpi: batch_mean, batch_std, batch_count = mpi_moments(x, axis=0, comm=self.comm) else: batch_mean, batch_std, batch_count = np.mean(x, axis=0), np.std(x, axis=0), x.shape[0] batch_var = np.square(batch_std) self.update_from_moments(batch_mean, batch_var, batch_count)
[docs] def update_from_moments(self, batch_mean, batch_var, batch_count): """ Update the running mean, variance, and count using new statistics. This method updates the current statistics by combining them with batch-level statistics, supporting both dictionary and array inputs. Parameters: batch_mean (Union[dict, ndarray]): Mean of the new batch. batch_var (Union[dict, ndarray]): Variance of the new batch. batch_count (Union[dict, int]): Number of samples in the new batch. Updates: self.mean (Union[dict, ndarray]): Updated running mean. self.var (Union[dict, ndarray]): Updated running variance. self.count (Union[dict, float]): Updated sample count. """ if isinstance(batch_mean, dict): for key in self.shape: delta = batch_mean[key] - self.mean[key] tot_count = self.count[key] + batch_count[key] new_mean = self.mean[key] + delta * batch_count[key] / tot_count m_a = self.var[key] * (self.count[key]) m_b = batch_var[key] * (batch_count[key]) M2 = m_a + m_b + np.square(delta) * self.count[key] * batch_count[key] / tot_count new_var = M2 / tot_count self.mean[key] = new_mean self.var[key] = new_var self.count[key] = tot_count else: delta = batch_mean - self.mean tot_count = self.count + batch_count new_mean = self.mean + delta * batch_count / tot_count m_a = self.var * (self.count) m_b = batch_var * (batch_count) M2 = m_a + m_b + np.square(delta) * self.count * batch_count / tot_count new_var = M2 / tot_count self.mean = new_mean self.var = new_var self.count = tot_count