summaryrefslogtreecommitdiff
path: root/pkgs/development/rocm-modules/hipblaslt/messagepack-compression-support.patch
blob: ace6b2b728ad21d97d8a55e326760fd2b9f6e800 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
diff --git a/Tensile/Source/lib/source/msgpack/MessagePack.cpp b/Tensile/Source/lib/source/msgpack/MessagePack.cpp
index de97929c..dbc397e0 100644
--- a/tensilelite/src/msgpack/MessagePack.cpp
+++ b/tensilelite/src/msgpack/MessagePack.cpp
@@ -28,6 +28,8 @@
 
 #include <Tensile/msgpack/Loading.hpp>
 
+#include <zstd.h>
+
 #include <fstream>
 
 namespace Tensile
@@ -86,6 +88,34 @@ namespace Tensile
                 return nullptr;
             }
 
+            // Check if the file is zstd compressed
+            char magic[4];
+            in.read(magic, 4);
+            bool isCompressed = (in.gcount() == 4 && magic[0] == '\x28' && magic[1] == '\xB5' && magic[2] == '\x2F' && magic[3] == '\xFD');
+            // Reset file pointer to the beginning
+            in.seekg(0, std::ios::beg);
+
+            if (isCompressed) {
+                // Decompress zstd file
+                std::vector<char> compressedData((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
+
+                size_t decompressedSize = ZSTD_getFrameContentSize(compressedData.data(), compressedData.size());
+                if (decompressedSize == ZSTD_CONTENTSIZE_ERROR || decompressedSize == ZSTD_CONTENTSIZE_UNKNOWN) {
+                    if(Debug::Instance().printDataInit())
+                        std::cout << "Error: Unable to determine decompressed size for " << filename << std::endl;
+                    return false;
+                }
+
+                std::vector<char> decompressedData(decompressedSize);
+                size_t dSize = ZSTD_decompress(decompressedData.data(), decompressedSize, compressedData.data(), compressedData.size());
+                if (ZSTD_isError(dSize)) {
+                    if(Debug::Instance().printDataInit())
+                        std::cout << "Error: ZSTD decompression failed for " << filename << std::endl;
+                    return false;
+                }
+
+                msgpack::unpack(result, decompressedData.data(), dSize);
+            } else {
             msgpack::unpacker unp;
             bool              finished_parsing;
             constexpr size_t  buffer_size = 1 << 19;
@@ -109,6 +139,7 @@ namespace Tensile
 
                 return nullptr;
             }
+            }
         }
         catch(std::runtime_error const& exc)
         {