"""Utilities for importing tensor data."""
# Copyright 2025 National Technology & Engineering Solutions of Sandia,
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
# U.S. Government retains certain rights in this software.
from __future__ import annotations
import os
from typing import TextIO
import numpy as np
from scipy.io import loadmat
import pyttb as ttb
from pyttb.pyttb_utils import to_memory_order
[docs]
def import_data(
    filename: str, index_base: int = 1
) -> ttb.sptensor | ttb.ktensor | ttb.tensor | np.ndarray:
    """Import tensor data from a text file.

    The first line of the file names the data type; the remaining content is
    parsed according to that type.

    Parameters
    ----------
    filename:
        File to import.
    index_base:
        Index basing allows interoperability (Primarily between python and
        MATLAB). Forwarded to :func:`import_sparse_array` for ``sptensor``
        data.

    Returns
    -------
    A ``ttb.tensor``, ``ttb.sptensor`` or ``ttb.ktensor`` depending on the
    data type declared in the file, or a ``numpy.ndarray`` for ``matrix``.

    Raises
    ------
    AssertionError
        If ``filename`` is not an existing file, or the data type read from
        the file is not one of ``tensor``, ``sptensor``, ``matrix``,
        ``ktensor``.
    """
    # Raise explicitly rather than ``assert False``: assert statements are
    # removed under ``python -O``, which would silently skip the validation.
    if not os.path.isfile(filename):
        raise AssertionError(f"File path {filename} does not exist.")

    with open(filename) as fp:
        # tensor type should be on the first line
        # valid: tensor, sptensor, matrix, ktensor
        data_type = import_type(fp)
        if data_type not in ("tensor", "sptensor", "matrix", "ktensor"):
            raise AssertionError(f"Invalid data type found: {data_type}")

        # Every supported type stores its shape right after the type line.
        shape = import_shape(fp)

        if data_type == "tensor":
            data = import_array(fp, np.prod(shape))
            return ttb.tensor(data, shape, copy=False)

        if data_type == "sptensor":
            nz = import_nnz(fp)
            subs, vals = import_sparse_array(fp, len(shape), nz, index_base)
            return ttb.sptensor(subs, vals, shape)

        if data_type == "matrix":
            mat = import_array(fp, np.prod(shape))
            return np.reshape(mat, np.array(shape))

        if data_type == "ktensor":
            r = import_rank(fp)
            weights = import_array(fp, r)
            factor_matrices = []
            # One factor matrix per tensor dimension.
            for _ in shape:
                fp.readline()  # Skip factor type
                fac_shape = import_shape(fp)
                fac = import_array(fp, np.prod(fac_shape))
                fac = to_memory_order(
                    np.reshape(fac, np.array(fac_shape)), order="F"
                )
                factor_matrices.append(fac)
            return ttb.ktensor(factor_matrices, weights, copy=False)

    raise ValueError("Failed to load tensor data")  # pragma: no cover
def import_data_bin(
    filename: str,
    index_base: int = 1,
) -> ttb.sptensor | ttb.ktensor | ttb.tensor | np.ndarray:
    """Import tensor-related data from a binary file.

    Parameters
    ----------
    filename:
        File to import; it is opened with :func:`numpy.load`.
    index_base:
        Index basing allows interoperability (Primarily between python and
        MATLAB).

    Returns
    -------
    Whatever :func:`_import_tensor_data` builds from the loaded entries.
    """

    def load_bin_data(filename: str):
        """Read the archive entries into the dict ``_import_tensor_data`` uses."""
        # The object returned by np.load holds an open file handle; the
        # context manager closes it once every array has been read into
        # memory, instead of leaking the handle.
        with np.load(filename, allow_pickle=False) as npzfile:
            # Optional entry: only present when factor matrices were stored.
            if "num_factor_matrices" in npzfile:
                num_factor_matrices = int(npzfile["num_factor_matrices"])
                factor_matrices = [
                    npzfile[f"factor_matrix_{i}"]
                    for i in range(num_factor_matrices)
                ]
            else:
                num_factor_matrices = None
                factor_matrices = None

            return {
                "header": npzfile["header"][0],
                "data": npzfile.get("data"),
                "shape": tuple(npzfile["shape"]) if "shape" in npzfile else None,
                "subs": npzfile.get("subs"),
                "vals": npzfile.get("vals"),
                "num_factor_matrices": num_factor_matrices,
                "factor_matrices": factor_matrices,
                "weights": npzfile.get("weights"),
            }

    return _import_tensor_data(filename, index_base, load_bin_data)
def import_data_mat(
    filename: str,
    index_base: int = 1,
) -> ttb.sptensor | ttb.ktensor | ttb.tensor | np.ndarray:
    """Import tensor-related data from a MATLAB file.

    Parameters
    ----------
    filename:
        File to import; it is read with :func:`scipy.io.loadmat`.
    index_base:
        Index basing allows interoperability (Primarily between python and
        MATLAB).

    Returns
    -------
    Whatever :func:`_import_tensor_data` builds from the loaded entries.
    """

    def load_mat_data(filename: str):
        """Read the MAT-file entries into the dict ``_import_tensor_data`` uses."""
        mat_data = loadmat(filename)
        header = mat_data["header"][0]

        # Optional entry: only present when factor matrices were stored.
        if "num_factor_matrices" in mat_data:
            # ``.item()`` extracts the single stored value; calling ``int()``
            # directly on an array with ndim > 0 is deprecated in NumPy.
            num_factor_matrices = int(
                np.asarray(mat_data["num_factor_matrices"]).item()
            )
            factor_matrices = [
                mat_data[f"factor_matrix_{i}"] for i in range(num_factor_matrices)
            ]
        else:
            num_factor_matrices = None
            factor_matrices = None

        return {
            # Only the first whitespace-separated token of the header is used.
            "header": header.split()[0],
            "data": mat_data.get("data"),
            "shape": tuple(mat_data["shape"][0]) if "shape" in mat_data else None,
            "subs": mat_data.get("subs"),
            "vals": mat_data.get("vals"),
            "num_factor_matrices": num_factor_matrices,
            "factor_matrices": factor_matrices,
            "weights": mat_data["weights"].flatten()
            if "weights" in mat_data
            else None,
        }

    return _import_tensor_data(filename, index_base, load_mat_data)
def _import_tensor_data(
    filename: str,
    index_base: int,
    data_loader,
) -> ttb.sptensor | ttb.ktensor | ttb.tensor | np.ndarray:
    """Build a tensor object from a file using a format-specific loader.

    Parameters
    ----------
    filename:
        File to import.
    index_base:
        Index basing allows interoperability (Primarily between python and
        MATLAB). Subtracted from the loaded ``subs`` for ``sptensor`` data.
    data_loader:
        Callable taking ``filename`` and returning a dict with a ``"header"``
        entry naming the data type plus the entries that type needs
        (``"data"``, ``"shape"``/``"subs"``/``"vals"``, or
        ``"factor_matrices"``/``"weights"``).

    Raises
    ------
    FileNotFoundError
        If ``filename`` is not an existing file.
    ValueError
        If the header is not one of the supported data types.
    """
    # Fail before invoking the loader when there is nothing to read.
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"File path {filename} does not exist.")

    contents = data_loader(filename)
    data_type = contents["header"]

    supported = ("tensor", "sptensor", "matrix", "ktensor")
    if data_type not in supported:
        raise ValueError(f"Invalid data type found: '{data_type}'")

    if data_type == "tensor":
        return ttb.tensor(contents["data"])

    if data_type == "sptensor":
        shape = contents["shape"]
        # Shift stored subscripts to zero-based indexing.
        subs = contents["subs"] - index_base
        vals = contents["vals"]
        return ttb.sptensor(subs, vals, shape)

    if data_type == "matrix":
        return contents["data"]

    if data_type == "ktensor":
        factor_matrices = contents["factor_matrices"]
        weights = contents["weights"]
        return ttb.ktensor(factor_matrices, weights)

    raise ValueError(f"Invalid data type found: {data_type}")
def import_type(fp: TextIO) -> str:
    """Read the next line of ``fp`` and return its first space-delimited token."""
    stripped_line = fp.readline().strip()
    # Everything before the first space (the whole line if there is none).
    type_token, _, _ = stripped_line.partition(" ")
    return type_token
def import_shape(fp: TextIO) -> tuple[int, ...]:
    """Extract the shape of something from a file.

    Reads two lines from ``fp``: the first holds the number of dimensions
    (its first space-delimited token), the second the space-separated size
    of each dimension.

    Raises
    ------
    AssertionError
        If the number of sizes read differs from the declared number of
        dimensions.
    ValueError
        If a token cannot be converted to ``int``.
    """
    n = int(fp.readline().strip().split(" ")[0])
    shape = tuple(int(d) for d in fp.readline().strip().split(" "))
    if len(shape) != n:
        # Raised explicitly: ``assert False`` is stripped under ``python -O``
        # and would let a mismatched shape through silently.
        raise AssertionError("Imported dimensions are not of expected size")
    return shape
def import_nnz(fp: TextIO) -> int:
    """Read the next line of ``fp`` and return its first token as the non-zero count."""
    stripped_line = fp.readline().strip()
    count_token, _, _ = stripped_line.partition(" ")
    return int(count_token)
def import_rank(fp: TextIO) -> int:
    """Read the next line of ``fp`` and return its first token as the rank."""
    stripped_line = fp.readline().strip()
    rank_token, _, _ = stripped_line.partition(" ")
    return int(rank_token)
def import_sparse_array(
    fp: TextIO, n: int, nz: int, index_base: int = 1
) -> tuple[np.ndarray, np.ndarray]:
    """Extract sparse data subs and vals from coordinate format data.

    Reads ``nz`` lines from ``fp``. On each line the last space-delimited
    token is the value and the preceding tokens are the subscripts, from
    which ``index_base`` is subtracted.

    Returns
    -------
    ``subs`` of shape ``(nz, n)`` with dtype int64 and ``vals`` of shape
    ``(nz, 1)``.
    """
    subs = np.zeros((nz, n), dtype="int64")
    vals = np.zeros((nz, 1))
    for row in range(nz):
        # Split into leading subscript tokens and the trailing value token.
        *coord_tokens, value_token = fp.readline().strip().split(" ")
        subs[row, :] = [np.int64(tok) - index_base for tok in coord_tokens]
        vals[row, 0] = value_token
    return subs, vals
def import_array(fp: TextIO, n: int | np.integer) -> np.ndarray:
    """Read ``n`` text-formatted numbers from ``fp`` via :func:`numpy.fromfile`.

    Parsing starts at the current position of ``fp`` and uses ``" "`` as the
    separator.
    """
    values = np.fromfile(fp, dtype=float, count=n, sep=" ")
    return values