1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
|
diff --git a/Tensile/Source/lib/source/msgpack/MessagePack.cpp b/Tensile/Source/lib/source/msgpack/MessagePack.cpp
index de97929c..dbc397e0 100644
--- a/tensilelite/src/msgpack/MessagePack.cpp
+++ b/tensilelite/src/msgpack/MessagePack.cpp
@@ -28,6 +28,8 @@
#include <Tensile/msgpack/Loading.hpp>
+#include <zstd.h>
+
#include <fstream>
namespace Tensile
@@ -86,6 +88,34 @@ namespace Tensile
return nullptr;
}
+ // Check if the file is zstd compressed
+ char magic[4];
+ in.read(magic, 4);
+ bool isCompressed = (in.gcount() == 4 && magic[0] == '\x28' && magic[1] == '\xB5' && magic[2] == '\x2F' && magic[3] == '\xFD');
+ // Reset file pointer to the beginning
+ in.seekg(0, std::ios::beg);
+
+ if (isCompressed) {
+ // Decompress zstd file
+ std::vector<char> compressedData((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
+
+ size_t decompressedSize = ZSTD_getFrameContentSize(compressedData.data(), compressedData.size());
+ if (decompressedSize == ZSTD_CONTENTSIZE_ERROR || decompressedSize == ZSTD_CONTENTSIZE_UNKNOWN) {
+ if(Debug::Instance().printDataInit())
+ std::cout << "Error: Unable to determine decompressed size for " << filename << std::endl;
+ return false;
+ }
+
+ std::vector<char> decompressedData(decompressedSize);
+ size_t dSize = ZSTD_decompress(decompressedData.data(), decompressedSize, compressedData.data(), compressedData.size());
+ if (ZSTD_isError(dSize)) {
+ if(Debug::Instance().printDataInit())
+ std::cout << "Error: ZSTD decompression failed for " << filename << std::endl;
+ return false;
+ }
+
+ msgpack::unpack(result, decompressedData.data(), dSize);
+ } else {
msgpack::unpacker unp;
bool finished_parsing;
constexpr size_t buffer_size = 1 << 19;
@@ -109,6 +139,7 @@ namespace Tensile
return nullptr;
}
+ }
}
catch(std::runtime_error const& exc)
{
|