import numpy as np


def distance_matrix(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    """Return the pairwise Euclidean distances between rows of X and rows of Y.

    Uses the expansion ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2 so the whole
    matrix is computed with one matrix product instead of a Python loop.

    Args:
        X: 2-D array of shape (n, d); each row is one sample.
        Y: 2-D array of shape (m, d); each row is one sample.

    Returns:
        Array of shape (n, m) whose entry [i, j] is the Euclidean distance
        between X[i] and Y[j]. All entries are finite and non-negative for
        finite inputs.
    """
    x2 = np.sum(X**2, axis=1, keepdims=True)  # (n, 1) squared row norms of X
    y2 = np.sum(Y**2, axis=1, keepdims=True)  # (m, 1) squared row norms of Y
    xy = X.dot(Y.T)  # (n, m) inner products
    sq = x2 - 2 * xy + y2.T  # broadcasts to (n, m) squared distances
    # Floating-point cancellation in the expansion can leave tiny negative
    # values for (nearly) coincident points. Clip them to 0 so np.sqrt never
    # yields NaN.
    np.maximum(sq, 0, out=sq)
    return np.sqrt(sq)