1//===- DXContainerEmitter.cpp - Convert YAML to a DXContainer -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// Binary emitter for yaml to DXContainer binary
11///
12//===----------------------------------------------------------------------===//
13
14#include "llvm/BinaryFormat/DXContainer.h"
15#include "llvm/ObjectYAML/ObjectYAML.h"
16#include "llvm/ObjectYAML/yaml2obj.h"
17#include "llvm/Support/Errc.h"
18#include "llvm/Support/Error.h"
19#include "llvm/Support/raw_ostream.h"
20
21using namespace llvm;
22
23namespace {
24class DXContainerWriter {
25public:
26  DXContainerWriter(DXContainerYAML::Object &ObjectFile)
27      : ObjectFile(ObjectFile) {}
28
29  Error write(raw_ostream &OS);
30
31private:
32  DXContainerYAML::Object &ObjectFile;
33
34  Error computePartOffsets();
35  Error validatePartOffsets();
36  Error validateSize(uint32_t Computed);
37
38  void writeHeader(raw_ostream &OS);
39  void writeParts(raw_ostream &OS);
40};
41} // namespace
42
43Error DXContainerWriter::validateSize(uint32_t Computed) {
44  if (!ObjectFile.Header.FileSize)
45    ObjectFile.Header.FileSize = Computed;
46  else if (*ObjectFile.Header.FileSize < Computed)
47    return createStringError(errc::result_out_of_range,
48                             "File size specified is too small.");
49  return Error::success();
50}
51
52Error DXContainerWriter::validatePartOffsets() {
53  if (ObjectFile.Parts.size() != ObjectFile.Header.PartOffsets->size())
54    return createStringError(
55        errc::invalid_argument,
56        "Mismatch between number of parts and part offsets.");
57  uint32_t RollingOffset =
58      sizeof(dxbc::Header) + (ObjectFile.Header.PartCount * sizeof(uint32_t));
59  for (auto I : llvm::zip(ObjectFile.Parts, *ObjectFile.Header.PartOffsets)) {
60    if (RollingOffset > std::get<1>(I))
61      return createStringError(errc::invalid_argument,
62                               "Offset mismatch, not enough space for data.");
63    RollingOffset =
64        std::get<1>(I) + sizeof(dxbc::PartHeader) + std::get<0>(I).Size;
65  }
66  if (Error Err = validateSize(RollingOffset))
67    return Err;
68
69  return Error::success();
70}
71
72Error DXContainerWriter::computePartOffsets() {
73  if (ObjectFile.Header.PartOffsets)
74    return validatePartOffsets();
75  uint32_t RollingOffset =
76      sizeof(dxbc::Header) + (ObjectFile.Header.PartCount * sizeof(uint32_t));
77  ObjectFile.Header.PartOffsets = std::vector<uint32_t>();
78  for (const auto &Part : ObjectFile.Parts) {
79    ObjectFile.Header.PartOffsets->push_back(RollingOffset);
80    RollingOffset += sizeof(dxbc::PartHeader) + Part.Size;
81  }
82  if (Error Err = validateSize(RollingOffset))
83    return Err;
84
85  return Error::success();
86}
87
88void DXContainerWriter::writeHeader(raw_ostream &OS) {
89  dxbc::Header Header;
90  memcpy(Header.Magic, "DXBC", 4);
91  memcpy(Header.FileHash.Digest, ObjectFile.Header.Hash.data(), 16);
92  Header.Version.Major = ObjectFile.Header.Version.Major;
93  Header.Version.Minor = ObjectFile.Header.Version.Minor;
94  Header.FileSize = *ObjectFile.Header.FileSize;
95  Header.PartCount = ObjectFile.Parts.size();
96  if (sys::IsBigEndianHost)
97    Header.swapBytes();
98  OS.write(reinterpret_cast<char *>(&Header), sizeof(Header));
99  SmallVector<uint32_t> Offsets(ObjectFile.Header.PartOffsets->begin(),
100                                ObjectFile.Header.PartOffsets->end());
101  if (sys::IsBigEndianHost)
102    for (auto &O : Offsets)
103      sys::swapByteOrder(O);
104  OS.write(reinterpret_cast<char *>(Offsets.data()),
105           Offsets.size() * sizeof(uint32_t));
106}
107
108void DXContainerWriter::writeParts(raw_ostream &OS) {
109  uint32_t RollingOffset =
110      sizeof(dxbc::Header) + (ObjectFile.Header.PartCount * sizeof(uint32_t));
111  for (auto I : llvm::zip(ObjectFile.Parts, *ObjectFile.Header.PartOffsets)) {
112    if (RollingOffset < std::get<1>(I)) {
113      uint32_t PadBytes = std::get<1>(I) - RollingOffset;
114      OS.write_zeros(PadBytes);
115    }
116    DXContainerYAML::Part P = std::get<0>(I);
117    RollingOffset = std::get<1>(I) + sizeof(dxbc::PartHeader);
118    uint32_t PartSize = P.Size;
119
120    OS.write(P.Name.c_str(), 4);
121    if (sys::IsBigEndianHost)
122      sys::swapByteOrder(P.Size);
123    OS.write(reinterpret_cast<const char *>(&P.Size), sizeof(uint32_t));
124
125    dxbc::PartType PT = dxbc::parsePartType(P.Name);
126
127    uint64_t DataStart = OS.tell();
128    switch (PT) {
129    case dxbc::PartType::DXIL: {
130      if (!P.Program)
131        continue;
132      dxbc::ProgramHeader Header;
133      Header.MajorVersion = P.Program->MajorVersion;
134      Header.MinorVersion = P.Program->MinorVersion;
135      Header.Unused = 0;
136      Header.ShaderKind = P.Program->ShaderKind;
137      memcpy(Header.Bitcode.Magic, "DXIL", 4);
138      Header.Bitcode.MajorVersion = P.Program->DXILMajorVersion;
139      Header.Bitcode.MinorVersion = P.Program->DXILMinorVersion;
140      Header.Bitcode.Unused = 0;
141
142      // Compute the optional fields if needed...
143      if (P.Program->DXILOffset)
144        Header.Bitcode.Offset = *P.Program->DXILOffset;
145      else
146        Header.Bitcode.Offset = sizeof(dxbc::BitcodeHeader);
147
148      if (P.Program->DXILSize)
149        Header.Bitcode.Size = *P.Program->DXILSize;
150      else
151        Header.Bitcode.Size = P.Program->DXIL ? P.Program->DXIL->size() : 0;
152
153      if (P.Program->Size)
154        Header.Size = *P.Program->Size;
155      else
156        Header.Size = sizeof(dxbc::ProgramHeader) + Header.Bitcode.Size;
157
158      uint32_t BitcodeOffset = Header.Bitcode.Offset;
159      if (sys::IsBigEndianHost)
160        Header.swapBytes();
161      OS.write(reinterpret_cast<const char *>(&Header),
162               sizeof(dxbc::ProgramHeader));
163      if (P.Program->DXIL) {
164        if (BitcodeOffset > sizeof(dxbc::BitcodeHeader)) {
165          uint32_t PadBytes = BitcodeOffset - sizeof(dxbc::BitcodeHeader);
166          OS.write_zeros(PadBytes);
167        }
168        OS.write(reinterpret_cast<char *>(P.Program->DXIL->data()),
169                 P.Program->DXIL->size());
170      }
171      break;
172    }
173    case dxbc::PartType::SFI0: {
174      // If we don't have any flags we can continue here and the data will be
175      // zeroed out.
176      if (!P.Flags.has_value())
177        continue;
178      uint64_t Flags = P.Flags->getEncodedFlags();
179      if (sys::IsBigEndianHost)
180        sys::swapByteOrder(Flags);
181      OS.write(reinterpret_cast<char *>(&Flags), sizeof(uint64_t));
182      break;
183    }
184    case dxbc::PartType::HASH: {
185      if (!P.Hash.has_value())
186        continue;
187      dxbc::ShaderHash Hash = {0, {0}};
188      if (P.Hash->IncludesSource)
189        Hash.Flags |= static_cast<uint32_t>(dxbc::HashFlags::IncludesSource);
190      memcpy(&Hash.Digest[0], &P.Hash->Digest[0], 16);
191      if (sys::IsBigEndianHost)
192        Hash.swapBytes();
193      OS.write(reinterpret_cast<char *>(&Hash), sizeof(dxbc::ShaderHash));
194      break;
195    }
196    case dxbc::PartType::Unknown:
197      break; // Skip any handling for unrecognized parts.
198    }
199    uint64_t BytesWritten = OS.tell() - DataStart;
200    RollingOffset += BytesWritten;
201    if (BytesWritten < PartSize)
202      OS.write_zeros(PartSize - BytesWritten);
203    RollingOffset += PartSize;
204  }
205}
206
207Error DXContainerWriter::write(raw_ostream &OS) {
208  if (Error Err = computePartOffsets())
209    return Err;
210  writeHeader(OS);
211  writeParts(OS);
212  return Error::success();
213}
214
215namespace llvm {
216namespace yaml {
217
218bool yaml2dxcontainer(DXContainerYAML::Object &Doc, raw_ostream &Out,
219                      ErrorHandler EH) {
220  DXContainerWriter Writer(Doc);
221  if (Error Err = Writer.write(Out)) {
222    handleAllErrors(std::move(Err),
223                    [&](const ErrorInfoBase &Err) { EH(Err.message()); });
224    return false;
225  }
226  return true;
227}
228
229} // namespace yaml
230} // namespace llvm
231