DDC 0.15.0
Loading...
Searching...
No Matches
save_npy.cpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#include <bit>
6#include <cstddef>
7#include <cstdint>
8#include <filesystem>
9#include <fstream>
10#include <numeric>
11#include <stdexcept>
12#include <string>
13#include <utility>
14#include <vector>
15
16#include "save_npy.hpp"
17
18namespace ddc::detail {
19
20NpyByteOrder get_byte_order(std::size_t const itemsize) noexcept
21{
22 if (itemsize == 1) {
23 return NpyByteOrder::not_applicable;
24 }
25
26 if (std::endian::native == std::endian::little) {
27 return NpyByteOrder::little_endian;
28 }
29
30 if (std::endian::native == std::endian::big) {
31 return NpyByteOrder::big_endian;
32 }
33
34 return NpyByteOrder::not_applicable;
35}
36
// Writes the 16-bit value `value_u16` to `os` as two bytes, least-significant
// byte first, independently of the host byte order.
void write_le(std::ostream& os, std::uint16_t const value_u16)
{
    constexpr unsigned int low_byte_mask = 0xFFU;

    // Widen explicitly to unsigned int before shifting and masking.
    auto const widened = static_cast<unsigned int>(value_u16);

    std::array<char, 2> const le_bytes {
            static_cast<char>(widened & low_byte_mask),
            static_cast<char>((widened >> 8U) & low_byte_mask),
    };

    os.write(le_bytes.data(), static_cast<std::streamsize>(le_bytes.size()));
}
49
50std::string NpyDtype::str() const
51{
52 return std::string(1, static_cast<char>(byte_order)) + static_cast<char>(kind)
53 + std::to_string(itemsize);
54}
55
56// See specification at https://numpy.org/neps/nep-0001-npy-format.html#format-specification-version-1-0
57void save_npy(std::ostream& os, NpyArrayView const& view)
58{
59 // Build shape string: (d0, d1, ..., dN,)
60 std::string shape_str = "(";
61 for (std::size_t const ext : view.shape) {
62 shape_str += std::to_string(ext);
63 shape_str += ", ";
64 }
65 shape_str += ")";
66
67 std::string const header_dict
68 = std::string("{'descr': '") + view.dtype.str() + "', 'fortran_order': "
69 + (view.fortran_order ? "True" : "False") + ", 'shape': " + shape_str + ", }";
70
71 // Pad header to a multiple of 16
72 std::size_t const non_padded_header_len = header_dict.size() + 1;
73 // magic(6) + major(1) + minor(1) + header_len(2) + header
74 std::size_t const alignment = 16;
75 std::size_t const remainder = (6 + 1 + 1 + 2 + non_padded_header_len) % alignment;
76 std::size_t const padding = (alignment - remainder) % alignment;
77 if (!std::in_range<std::uint16_t>(non_padded_header_len + padding)) {
78 throw std::runtime_error("save_npy: header too large for npy v1.0.");
79 }
80 auto const header_len = static_cast<std::uint16_t>(non_padded_header_len + padding);
81
82 // magic string
83 os.write("\x93NUMPY", 6);
84 // major version
85 os.put(1);
86 // minor version
87 os.put(0);
88 // header length in little-endian
89 write_le(os, header_len);
90 // header + padding + newline
91 os.write(header_dict.data(), header_dict.size());
92 os.write(" ", padding);
93 os.put('\n');
94
95 // Raw data
96 std::size_t const n_elems
97 = std::accumulate(view.shape.begin(), view.shape.end(), 1ULL, std::multiplies<> {});
98 os.write(reinterpret_cast<char const*>(view.data), n_elems * view.dtype.itemsize);
99}
100
101void save_npy(std::filesystem::path const& filename, NpyArrayView const& view)
102{
103 std::ofstream file(filename, std::ios::binary);
104 file.exceptions(std::ios::failbit | std::ios::badbit);
105
106 save_npy(file, view);
107}
108
109} // namespace ddc::detail
The top-level namespace of DDC.