1 module tagion.wasm.WasmWriter;
2 
3 import std.bitmanip : nativeToLittleEndian;
4 import std.outbuffer;
5 import std.traits : isIntegral, isFloatingPoint, EnumMembers, hasMember, Unqual,
6     TemplateArgsOf, PointerTarget, getUDAs, isPointer, ConstOf, ForeachType, FieldNameTuple;
7 import std.algorithm;
8 import std.array : join;
9 import std.exception : assumeUnique;
10 import std.format;
11 import std.meta : Replace, staticMap;
12 import std.range : lockstep;
13 import std.range.primitives : isInputRange;
14 import std.stdio;
15 import std.typecons : Tuple;
16 
17 import LEB128 = tagion.utils.LEB128;
18 import tagion.wasm.WasmBase;
19 import tagion.wasm.WasmException;
20 import tagion.wasm.WasmReader;
21 
22 @safe class WasmWriter {
23 
24     alias ReaderSections = WasmReader.Sections;
25 
26     alias ReaderCustom = ReaderSections[Section.CUSTOM];
27 
28     alias Sections = SectionsT!(WasmSection);
29 
    // The first element (Custom) of the Sections sequence is replaced with CustomList
31     alias Modules = Tuple!(Replace!(WasmSection.Custom, WasmSection.CustomList, Sections));
32     //    alias InterfaceModule = InterfaceModuleT!(Sections);
33 
34     alias ReaderSecType(Section sec) = TemplateArgsOf!(ReaderSections[sec].SecRange)[1];
35 
36     Modules mod;
37     Sections[Sec] section(Section Sec)() {
38         if (!mod[Sec]) {
39             mod[Sec] = new Sections[Sec];
40         }
41         return mod[Sec];
42     }
43 
    /**
     * Builds a writer from an existing reader.
     * A WasmLoader is created and passed to `reader(...)`; the loader's
     * callbacks fill `mod`.
     * NOTE(review): presumably the reader's call operator walks every section
     * and calls back into the loader - confirm in WasmReader.
     * Params:
     *   reader = the reader the sections are taken from
     */
    this(ref const(WasmReader) reader) {
        auto loader = new WasmLoader;
        reader(loader);
    }
48 
    /// Creates an empty writer; `mod` keeps its default (unset) sections.
    this() pure nothrow @nogc {
        // empty
    }
52 
    /**
     * Factory shorthand so that `WasmWriter(reader)` can be written without `new`.
     * Params:
     *   reader = the reader forwarded to the constructor
     * Returns: a newly allocated WasmWriter constructed from `reader`
     */
    static WasmWriter opCall(ref const(WasmReader) reader) {
        return new WasmWriter(reader);
    }
56 
    /**
     * Compile-time lookup: evaluates to the Section member `E` for which
     * `TList[E]` is exactly the type `T`.
     * If no element of `TList` matches, the eponymous member `AsType` is never
     * declared.
     */
    template AsType(T, TList...) {
        static foreach (E; EnumMembers!Section) {
            static if (is(T == TList[E])) {
                enum AsType = E;
            }
        }
    }

    /// Section index of `T`, looked up among the pointer targets of the module types.
    /// NOTE(review): this refers to `Module.Types`, but only `Modules` is declared in
    /// the visible source - confirm that `Module` exists or that this is a typo.
    enum asType(T) = AsType!(T, staticMap!(PointerTarget, Module.Types));
66 
    /**
     * Compile-time lookup: wraps `SecType` in `WasmSection.SectionT` and evaluates
     * to the Section member `E` for which `TList[E]` is exactly that wrapped type.
     * If no element of `TList` matches, the eponymous member `FromSecType` is
     * never declared.
     */
    template FromSecType(SecType, TList...) {
        alias T = WasmSection.SectionT!SecType;
        static foreach (E; EnumMembers!Section) {
            static if (is(T == TList[E])) {
                enum FromSecType = E;
            }
        }
    }

    /// Section index of the section whose element type is `T`, looked up in `Sections`.
    enum fromSecType(T) = FromSecType!(T, Sections);
77 
    /**
     * Generates, through a string mixin, one loader callback for `sec_type`.
     * The generated method is named `secname(sec_type)`, takes the matching
     * reader section by reference, stores `sec_type` in `previous_sec` and
     * forwards the reader section to `section_secT!(sec_type)`.
     * `previous_sec` and `section_secT` must be visible where this template is
     * mixed in (WasmLoader provides both).
     */
    mixin template loadSec(Section sec_type) {
        // Source text of the generated method; the %s placeholders are filled with
        // the method name and (twice) the Section member name.
        enum code = format(q{
                final void %s(ref ConstOf!(ReaderSections[Section.%s]) sec) {
                    enum sec_type=Section.%s;
                    previous_sec=sec_type;
                    section_secT!(sec_type)(sec);
                }
            }, secname(sec_type), sec_type, sec_type);
        mixin(code);
    }
88 
89     class WasmLoader : WasmReader.InterfaceModule {
90         alias SecElement(Section sec) = TemplateArgsOf!(Sections[sec])[0];
91         private Section previous_sec;
92         void section_secT(Section sec)(ref ConstOf!(ReaderSections[sec]) _reader_sec) {
93             if (_reader_sec !is null) {
94                 alias ModuleType = Sections[sec];
95                 alias SectionElement = TemplateArgsOf!(ModuleType);
96                 auto _sec = new ModuleType;
97                 mod[sec] = _sec;
98                 foreach (s; _reader_sec[]) {
99                     _sec.sectypes ~= SecElement!(sec)(s);
100                 }
101             }
102         }
103 
104         final void custom_sec(ref ConstOf!(ReaderCustom) sec) {
105             mod[Section.CUSTOM].add(previous_sec, sec);
106         }
107 
108         final void start_sec(ref ConstOf!(ReaderSections[Section.START]) sec) {
109             previous_sec = Section.START;
110             mod[Section.START] = new WasmSection.Start(sec);
111         }
112 
113         static foreach (Sec; EnumMembers!Section) {
114             static if (Sec !is Section.START && Sec !is Section.CUSTOM) {
115                 mixin loadSec!Sec;
116 
117             }
118         }
119 
120     }
121 
122     immutable(ubyte[]) serialize() const {
123         OutBuffer[EnumMembers!Section.length + 1] buffers;
124         OutBuffer[EnumMembers!Section.length + 1] custom_buffers;
125         scope (exit) {
126             buffers = null;
127             custom_buffers = null;
128         }
129         size_t output_size;
130         Section previous_sec;
131         void output_custom(const(WasmSection.Custom) custom) {
132             if (custom) {
133                 custom_buffers[previous_sec] = new OutBuffer;
134                 custom_buffers[previous_sec].reserve(custom.guess_size);
135                 custom.serialize(custom_buffers[previous_sec]);
136             }
137         }
138 
139         foreach (E; EnumMembers!Section) {
140             foreach (const sec; mod[Section.CUSTOM].list[previous_sec]) {
141                 output_custom(sec);
142             }
143             static if (E !is Section.CUSTOM) {
144                 if (mod[E]!is null) {
145                     buffers[E] = new OutBuffer;
146                     static if (E !is Section.CUSTOM) {
147                         mod[E].serialize(buffers[E]);
148                         output_size += buffers[E].offset + uint.sizeof + Section.sizeof;
149                     }
150                 }
151             }
152             previous_sec = E;
153         }
154         foreach (const sec; mod[Section.CUSTOM].list[previous_sec]) {
155             output_custom(sec);
156         }
157         previous_sec = Section.CUSTOM;
158         auto output = new OutBuffer;
159         output_size += magic.length + wasm_version.length;
160         output.reserve(output_size);
161         output.write(magic);
162         output.write(wasm_version);
163         void append_buffer(const OutBuffer b, const(Section) sec) {
164             if (b !is null) {
165                 output.write(cast(ubyte) sec);
166                 output.write(LEB128.encode(b.offset));
167                 output.write(b);
168             }
169         }
170 
171         foreach (E; EnumMembers!Section) {
172             append_buffer(buffers[E], E);
173             append_buffer(custom_buffers[E], Section.CUSTOM);
174         }
175         append_buffer(custom_buffers[$ - 1], Section.CUSTOM);
176         return output.toBytes.idup;
177     }
178 
179     struct WasmSection {
        /**
         * Adds a generic `serialize(ref OutBuffer)` to the aggregate it is mixed into.
         * The fields are written in declaration order (`this.tupleof`), so the
         * field order of the aggregate defines the byte layout.
         * Per field:
         *   - struct/class fields: their own `serialize` is called
         *   - one-byte types: written as a single ubyte
         *   - other integral types: LEB128 encoded
         *   - floating point types: little-endian bytes
         *   - arrays: LEB128 length prefix (omitted when the field carries the
         *     UDA `@Section(Section.CODE)`), followed by the elements
         * Any other field type is rejected at compile time.
         */
        mixin template Serialize() {
            final void serialize(ref OutBuffer bout) const {
                alias MainType = typeof(this);
                // Pre-allocate when the aggregate can estimate its own size.
                static if (hasMember!(MainType, "guess_size")) {
                    bout.reserve(guess_size);
                }
                foreach (i, m; this.tupleof) {
                    alias T = typeof(m);
                    static if (is(T == struct) || is(T == class)) {
                        m.serialize(bout);
                    }
                    else {
                        // The one-byte test comes first, so byte-sized values are
                        // written raw instead of LEB128 encoded.
                        static if (T.sizeof == 1) {
                            bout.write(cast(ubyte) m);
                        }
                        else static if (isIntegral!T) {
                            bout.write(LEB128.encode(m));
                        }
                        else static if (isFloatingPoint!T) {
                            bout.write(nativeToLittleEndian(m));
                        }
                        else static if (is(T : U[], U)) {
                            // A field tagged @Section(Section.CODE) is written without a length prefix.
                            alias spec = getUDAs!(this.tupleof[i], Section);
                            static if ((spec.length == 0) || (spec[0]!is Section.CODE)) {
                                // Check to avoid adding the length for an expression
                                bout.write(LEB128.encode(m.length));
                            }
                            static if (U.sizeof == 1) {
                                // Byte-sized elements are copied verbatim.
                                bout.write(cast(const(ubyte[])) m);
                            }
                            else static if (isIntegral!U) {
                                m.each!((e) => bout.write(LEB128.encode(e)));
                            }
                            else static if (hasMember!(U, "serialize")) {
                                foreach (e; m) {
                                    e.serialize(bout);
                                }
                            }
                            else {
                                static assert(0,
                                        format("Array type %s is not supported", T.stringof));
                            }
                        }
                    else {
                            static assert(0, format("Type %s is not supported", T.stringof));
                        }
                    }
                }
            }
        }
230 
231         struct Limit {
232             Limits lim;
233             uint from;
234             uint to;
235             this(ref const(WasmReader.Limit) l) {
236                 lim = l.lim;
237                 from = l.from;
238                 to = l.to;
239             }
240 
241             void serialize(ref OutBuffer bout) const {
242                 bout.write(cast(ubyte) lim);
243                 bout.write(LEB128.encode(from));
244                 with (Limits) {
245                     final switch (lim) {
246                     case INFINITE:
247                         // Empty
248                         break;
249                     case RANGE:
250                         bout.write(LEB128.encode(to));
251                         break;
252                     }
253                 }
254             }
255         }
256 
257         static class SectionT(SecType) {
258             SecType[] sectypes;
259             @property size_t length() const pure nothrow {
260                 return sectypes.length;
261             }
262 
263             size_t guess_size() const pure nothrow {
264                 if (sectypes.length > 0) {
265                     static if (hasMember!(SecType, "guess_size")) {
266                         return sectypes.map!(s => s.guess_size()).sum + uint.sizeof;
267                     }
268                     else {
269                         return sectypes.length * SecType.sizeof + uint.sizeof;
270                     }
271                 }
272                 return 0;
273             }
274 
275             mixin Serialize;
276         }
277 
278         static class Custom {
279             string name;
280             immutable(ubyte)[] bytes;
281             size_t guess_size() const pure nothrow {
282                 return name.length + bytes.length + uint.sizeof * 2;
283             }
284 
285             this(string name, immutable(ubyte[]) bytes) pure nothrow {
286                 this.name = name;
287                 this.bytes = bytes;
288             }
289 
290             import tagion.hibon.Document;
291 
292             this(string name, const(Document) doc) pure nothrow {
293                 this.name = name;
294                 bytes = doc.data[doc.begin .. $];
295             }
296 
297             this(_ReaderCustom)(const(_ReaderCustom) s) pure nothrow {
298                 name = s.name;
299                 bytes = s.bytes;
300             }
301 
302             mixin Serialize;
303         }
304 
        /**
         * Custom sections grouped by position.
         * `list[i]` holds the custom sections filed under index `i`; the loader
         * passes the id of the section that was read last as the index.
         */
        struct CustomList {
            // One slot per Section member plus one spare slot.
            Custom[][EnumMembers!(Section).length + 1] list;
            /// Appends a new Custom, built from the reader section `s`, to slot `sec_index`.
            void add(_ReaderCustom)(const size_t sec_index, const(_ReaderCustom) s) {
                list[sec_index] ~= new Custom(s);
            }

        }
312 
313         struct FuncType {
314             Types type;
315             immutable(Types)[] params;
316             immutable(Types)[] results;
317             size_t guess_size() const pure nothrow {
318                 return params.length + results.length + uint.sizeof * 2 + Types.sizeof;
319             }
320 
321             this(const Types type, immutable(Types)[] params, immutable(Types)[] results) {
322                 this.type = type;
323                 this.params = params;
324                 this.results = results;
325             }
326 
327             this(ref const(ReaderSecType!(Section.TYPE)) s) {
328                 type = s.type;
329                 params = s.params;
330                 results = s.results;
331             }
332 
333             mixin Serialize;
334         }
335 
336         alias Type = SectionT!(FuncType);
337 
338         struct ImportType {
339             string mod;
340             string name;
341             ImportDesc importdesc;
342             alias ReaderImportType = ReaderSecType!(Section.IMPORT);
343             alias ReaderImportDesc = ReaderImportType.ImportDesc;
344             size_t guess_size() const pure nothrow {
345                 return mod.length + name.length + uint.sizeof * 2 + ImportDesc.sizeof;
346             }
347 
348             mixin Serialize;
349             struct ImportDesc {
350                 struct FuncDesc {
351                     uint funcidx;
352                     this(const(ReaderImportDesc.FuncDesc) f) {
353                         funcidx = f.funcidx;
354                     }
355 
356                     mixin Serialize;
357                 }
358 
359                 struct TableDesc {
360                     Types type;
361                     Limit limit;
362                     this(const(ReaderImportDesc.TableDesc) t) {
363                         type = t.type;
364                         limit = t.limit;
365                     }
366 
367                     mixin Serialize;
368                 }
369 
370                 struct MemoryDesc {
371                     Limit limit;
372                     this(const(ReaderImportDesc.MemoryDesc) m) {
373                         limit = m.limit;
374                     }
375 
376                     mixin Serialize;
377                 }
378 
379                 struct GlobalDesc {
380                     Types type;
381                     Mutable mut;
382                     this(const Types type, const Mutable mut = Mutable.CONST) {
383                         this.type = type;
384                         this.mut = mut;
385                     }
386 
387                     this(const(ReaderImportDesc.GlobalDesc) g) {
388                         mut = g.mut;
389                         type = g.type;
390                     }
391 
392                     mixin Serialize;
393                 }
394 
395                 protected union {
396                     @(IndexType.FUNC) FuncDesc _funcdesc;
397                     @(IndexType.TABLE) TableDesc _tabledesc;
398                     @(IndexType.MEMORY) MemoryDesc _memorydesc;
399                     @(IndexType.GLOBAL) GlobalDesc _globaldesc;
400                 }
401 
402                 protected IndexType _desc;
403                 void serialize(ref OutBuffer bout) const {
404                     with (IndexType)
405                         bout.write(cast(ubyte) _desc);
406                     final switch (_desc) {
407                         foreach (E; EnumMembers!IndexType) {
408                     case E:
409                             get!E.serialize(bout);
410                             break;
411                         }
412                     }
413                 }
414 
415                 auto get(IndexType IType)() const pure
416                 in {
417                     assert(_desc is IType);
418                 }
419                 do {
420                     //static foreach (m; __traits(allMembers, ImportDesc)) {
421                     static foreach (m; FieldNameTuple!ImportDesc) {
422                         {
423                             enum get_indextype_code = format(q{enum get_indextype=getUDAs!(%s, IndexType);},
424                                         m);
425                             mixin(get_indextype_code);
426                             static if (get_indextype.length is 1) {
427                                 static if (IType is get_indextype[0]) {
428                                     enum return_code = format(q{auto result=%s;}, m);
429                                     mixin(return_code);
430                                     return result;
431                                 }
432                             }
433                         }
434                     }
435                 }
436 
437                 @property IndexType desc() const pure nothrow {
438                     return _desc;
439                 }
440 
441                 this(T)(ref const(T) desc) {
442                     with (IndexType) {
443                         static if (is(T : const(FuncDesc))) {
444                             _desc = FUNC;
445                             _funcdesc = desc;
446                         }
447                         else static if (is(T : const(TableDesc))) {
448                             _desc = TABLE;
449                             _tabledesc = desc;
450                         }
451                         else static if (is(T : const(MemoryDesc))) {
452                             _desc = MEMORY;
453                             _memorydesc = desc;
454                         }
455                         else static if (is(T : const(GlobalDesc))) {
456                             _desc = GLOBAL;
457                             _globaldesc = desc;
458                         }
459                         else {
460                             static assert(0, format("Type %s is not supported", T.stringof));
461                         }
462                     }
463                 }
464 
465                 this(ref const(ReaderImportDesc) importdesc) {
466                     with (IndexType) {
467                         final switch (importdesc.desc) {
468                         case FUNC:
469                             _funcdesc = FuncDesc(importdesc.get!(FUNC));
470                             break;
471                         case TABLE:
472                             _tabledesc = TableDesc(importdesc.get!(TABLE));
473                             break;
474                         case MEMORY:
475                             _memorydesc = MemoryDesc(importdesc.get!(MEMORY));
476                             break;
477                         case GLOBAL:
478                             _globaldesc = GlobalDesc(importdesc.get!(GLOBAL));
479                             break;
480                         }
481                     }
482                 }
483             }
484 
485             this(T)(string mod, string name, T desc) pure {
486                 this.mod = mod;
487                 this.name = name;
488                 this.importdesc = ImportDesc(desc);
489             }
490 
491             this(ref const(ReaderImportType) s) {
492                 this.mod = s.mod;
493                 this.name = s.name;
494                 this.importdesc = ImportDesc(s.importdesc);
495             }
496 
497         }
498 
499         alias Import = SectionT!(ImportType);
500 
501         struct TypeIndex {
502             uint idx;
503             this(const uint typeidx) {
504                 this.idx = typeidx;
505             }
506 
507             this(ref const(ReaderSecType!(Section.FUNCTION)) f) {
508                 idx = f.idx;
509             }
510 
511             mixin Serialize;
512         }
513 
514         alias Function = SectionT!(TypeIndex);
515 
516         struct TableType {
517             Types type;
518             Limit limit;
519             this(ref const(ReaderSecType!(Section.TABLE)) t) {
520                 type = t.type;
521                 limit = Limit(t.limit);
522             }
523 
524             mixin Serialize;
525         }
526 
527         alias Table = SectionT!(TableType);
528 
529         struct MemoryType {
530             Limit limit;
531             this(ref const(ReaderSecType!(Section.MEMORY)) m) {
532                 limit = Limit(m.limit);
533             }
534 
535             mixin Serialize;
536         }
537 
538         alias Memory = SectionT!(MemoryType);
539 
540         struct GlobalType {
541             alias GlobalDesc = ImportType.ImportDesc.GlobalDesc;
542             GlobalDesc global;
543             @Section(Section.CODE) immutable(ubyte)[] expr;
544             this(const GlobalDesc global, immutable(ubyte)[] expr) {
545                 this.global = global;
546                 this.expr = expr;
547             }
548 
549             this(ref const(ReaderSecType!(Section.GLOBAL)) g) {
550                 global = ImportType.ImportDesc.GlobalDesc(g.global);
551                 expr = g.expr;
552             }
553 
554             mixin Serialize;
555         }
556 
557         alias Global = SectionT!(GlobalType);
558 
559         struct ExportType {
560             string name;
561             IndexType desc;
562             uint idx;
563             size_t guess_size() const pure nothrow {
564                 return name.length + uint.sizeof + ImportType.ImportDesc.sizeof;
565             }
566 
567             this(string name, const uint idx, const IndexType desc = IndexType.FUNC) {
568                 this.name = name;
569                 this.desc = desc;
570                 this.idx = idx;
571             }
572 
573             this(ref const(ReaderSecType!(Section.EXPORT)) e) {
574                 name = e.name;
575                 desc = IndexType(e.desc);
576                 idx = e.idx;
577             }
578 
579             mixin Serialize;
580         }
581 
582         alias Export = SectionT!(ExportType);
583 
584         static class Start {
585             uint idx;
586             alias ReaderStartType = ReaderSections[Section.START];
587             this(ref ConstOf!(ReaderStartType) s) {
588                 idx = s.idx;
589             }
590 
591             mixin Serialize;
592         }
593 
594         struct ElementType {
595             uint tableidx;
596             @Section(Section.CODE) immutable(ubyte)[] expr;
597             immutable(uint)[] funcs;
598             this(ref const(ReaderSecType!(Section.ELEMENT)) e) {
599                 tableidx = e.tableidx;
600                 expr = e.expr;
601                 funcs = e.funcs;
602             }
603 
604             mixin Serialize;
605         }
606 
607         alias Element = SectionT!(ElementType);
608 
609         struct CodeType {
610             Local[] locals;
611             @Section(Section.CODE) immutable(ubyte)[] expr;
612             size_t guess_size() const pure nothrow {
613                 return locals.length * Local.sizeof + expr.length + 2 * uint.sizeof;
614             }
615 
616             struct Local {
617                 uint count;
618                 Types type;
619                 mixin Serialize;
620             }
621 
622             static Local[] toLocals(scope const(Types[]) types) pure nothrow {
623                 Local[] result;
624                 void compact(const(Types[]) _types) {
625                     if (_types.length) {
626                         const count = cast(uint) _types.count(_types[0]);
627                         result ~= Local(count, _types[0]);
628                         compact(_types[count .. $]);
629                     }
630                 }
631 
632                 return result;
633 
634             }
635 
636             this(Local[] locals, immutable(ubyte[]) expr) {
637                 this.locals = locals;
638                 this.expr = expr;
639             }
640 
641             this(scope const(Types[]) types, immutable(ubyte[]) expr) {
642                 this.locals = toLocals(types);
643                 this.expr = expr;
644             }
645 
646             @trusted this(ref const(ReaderSecType!(Section.CODE)) c) {
647                 locals = new Local[c.locals.length];
648                 foreach (ref l, reader_l; lockstep(locals, c.locals)) {
649                     l.count = reader_l.count;
650                     l.type = reader_l.type;
651                 }
652                 expr = c[].data;
653             }
654 
655             ExprRange opSlice() const {
656                 return ExprRange(expr);
657             }
658 
659             void serialize(ref OutBuffer bout) const {
660                 auto tmp_out = new OutBuffer;
661                 tmp_out.reserve(guess_size);
662                 tmp_out.write(LEB128.encode(locals.length));
663                 locals.each!((l) => l.serialize(tmp_out));
664                 tmp_out.write(expr);
665                 bout.write(LEB128.encode(tmp_out.offset));
666                 bout.write(tmp_out.toBytes);
667             }
668 
669             immutable(ubyte[]) serialize() const @trusted {
670                 auto bout = new OutBuffer;
671                 serialize(bout);
672                 return assumeUnique(bout.toBytes);
673             }
674         }
675 
676         alias Code = SectionT!(CodeType);
677 
678         struct DataType {
679             uint idx;
680             @Section(Section.CODE) immutable(ubyte)[] expr;
681             string base;
682             this(ref const(ReaderSecType!(Section.DATA)) d) {
683                 idx = d.idx;
684                 expr = d.expr;
685                 base = d.base;
686             }
687 
688             mixin Serialize;
689         }
690 
691         alias Data = SectionT!(DataType);
692 
693     }
694 }
695 
// Round-trip test: reads a .wasm file, rebuilds it through WasmWriter and
// expects the writer's output to equal the reader's serialization.
// Disabled via version (none), so it is never compiled.
version (none) unittest {
    import std.exception : assumeUnique;
    import std.file;
    import std.stdio;
    import tagion.wavm.Wast;

    // Reads at most upTo bytes of the file and returns them as immutable data.
    @trusted static immutable(ubyte[]) fread(R)(R name, size_t upTo = size_t.max) {
        import std.file : _read = read;

        auto data = cast(ubyte[]) _read(name, upTo);
        // writefln("read data=%s", data);
        return assumeUnique(data);
    }

    // Alternative input files are kept commented out for manual testing.
    //    string filename="../tests/wasm/func_1.wasm";
    string filename = "../tests/wasm/global_1.wasm";
    //    string filename="../tests/wasm/imports_1.wasm";
    //    string filename="../tests/wasm/table_copy_2.wasm";
    //    string filename="../tests/wasm/memory_2.wasm";
    //    string filename="../tests/wasm/start_4.wasm";
    //    string filename="../tests/wasm/address_1.wasm";
    //    string filename="../tests/wasm/data_4.wasm";
    //    string filename="../tests/web_gas_gauge.wasm";//wasm/imports_1.wasm";
    immutable read_data = fread(filename);
    auto wasm_reader = WasmReader(read_data);
    // Print the module in text form to stdout.
    Wast(wasm_reader, stdout).serialize();

    writefln("wasm_reader.serialize=%s", wasm_reader.serialize);
    auto wasm_writer = WasmWriter(wasm_reader);

    writeln("wasm_writer.serialize");
    writefln("wasm_writer.serialize=%s", wasm_writer.serialize);
    // The writer must reproduce the reader's bytes exactly.
    assert(wasm_reader.serialize == wasm_writer.serialize);
}