1 module tagion.wasm.WasmGas;
2 
3 import tagion.wasm.WasmBase;
4 import tagion.wasm.WasmExpr;
5 import tagion.wasm.WasmWriter;
6 import tagion.utils.LEB128;
7 
8 import std.algorithm.comparison : max;
9 import std.algorithm.iteration : map;
10 import std.array : array;
11 import std.format;
12 import std.meta : staticMap;
13 import std.outbuffer;
14 import std.traits : EnumMembers, PointerTarget, TemplateArgsOf, Unqual;
15 import std.typecons : Tuple;
16 
17 // import std.stdio;
18 
/++
 Gas-metering instrumentation for a WebAssembly module held by a WasmWriter.

 `modify` rewrites every function body already present in the CODE section so
 that it calls a gas counter function with the static instruction cost of the
 function and of its blocks. It then appends:
 - a mutable i32 global (the gas gauge, initialised to 0),
 - the gas counter function,
 - an exported `$set_gas_gauge(i32)` function,
 - an exported `$read_gas_gauge() -> i32` function.
 +/
struct WasmGas {
    /// Export name of the function that loads the gas gauge
    enum set_gas_gauge = "$set_gas_gauge";
    /// Export name of the function that returns the current gas gauge
    enum read_gas_gauge = "$read_gas_gauge";
    protected WasmWriter writer;

    /++
     Params:
     writer = the module that `modify` will instrument
     NOTE(review): `writer` is stored by assignment; if WasmWriter is a struct
     (not a class), sections created through this copy are not visible to the
     caller - confirm WasmWriter is a reference type.
     +/
    this(ref WasmWriter writer) {
        this.writer = writer;
    }

    alias GlobalDesc = WasmWriter.WasmSection.ImportType.ImportDesc.GlobalDesc;
    alias Global = WasmWriter.WasmSection.Global;
    alias Type = WasmWriter.WasmSection.Type;
    alias Function = WasmWriter.WasmSection.Function;
    alias Code = WasmWriter.WasmSection.Code;
    alias GlobalType = WasmWriter.WasmSection.GlobalType;
    alias FuncType = WasmWriter.WasmSection.FuncType;
    alias TypeIndex = WasmWriter.WasmSection.TypeIndex;
    alias CodeType = WasmWriter.WasmSection.CodeType;
    alias ExportType = WasmWriter.WasmSection.ExportType;

    /++
     Appends a section type to the section that stores values of type `SecType`,
     creating that section first if the module does not have it yet.
     Params:
     sectype = the entry to append
     Returns:
     the index of the appended entry within that section
     +/
    uint inject(SecType)(SecType sectype) {
        uint idx;
        enum SectionId = WasmWriter.fromSecType!SecType;
        if (writer.mod[SectionId] is null) {
            idx = 0;
            writer.mod[SectionId] = new WasmWriter.WasmSection.SectionT!SecType;
            writer.mod[SectionId].sectypes = [sectype];
        }
        else {
            idx = cast(uint)(writer.mod[SectionId].sectypes.length);
            writer.mod[SectionId].sectypes ~= sectype;
        }
        return idx;
    }

    /// Callback that emits the gas-charging instructions for the cost `gas` into `bout`
    alias InjectGas = void delegate(scope OutBuffer bout, const uint gas);

    /++
     Rewrites every function body in the CODE section, inserting the code
     produced by `inject_gas`:
     - at the start of each function body, with the summed `cost` (from
       `instrTable`) of the instructions at function level;
     - right after the opcode of every block-type instruction other than `if`
       (i.e. at the top of the block body), with the summed cost of that
       block's own instructions.
     An `if` block gets no charge of its own; instead the larger of its
     then/else branch costs is added to the cost of the enclosing scope.
     Params:
     inject_gas = emits the gas-charging code for a given cost
     +/
    package void perform_gas_inject(InjectGas inject_gas) {
        auto code_sec = writer.mod[Section.CODE];

        alias GasResult = Tuple!(uint, "gas", IR, "irtype");

        /+
         Re-emits the instructions of one scope from `expr` into `bout`, recursing
         into nested blocks, until an END-type instruction closes the scope.
         Returns: the accumulated cost of the scope and the opcode that closed it
         (the `if` handling expects ELSE here for the then-branch); IR.END is
         returned when `expr` is exhausted.
         +/
        const(GasResult) inject_gas_funcs(ref scope OutBuffer bout, ref ExprRange expr) {
            scope wasmexpr = WasmExpr(bout);
            uint gas_count;
            while (!expr.empty) {
                const elm = expr.front;
                const instr = instrTable.get(elm.code, illegalInstr);
                gas_count += instr.cost;
                expr.popFront;
                with (IRType) {
                    final switch (instr.irtype) {
                    case PREFIX:
                    case CODE:
                        wasmexpr(elm.code);
                        break;
                    case BLOCK:
                        wasmexpr(elm.code, elm.types[0]);
                        scope block_bout = new OutBuffer;
                        pragma(msg, "fixme(cbr): add block_block_out.reserve");
                        const block_result = inject_gas_funcs(block_bout, expr);
                        if (elm.code is IR.IF) {
                            // Only one branch executes, so the enclosing scope is
                            // charged with the more expensive of the two.
                            uint if_gas_count = block_result.gas;
                            if (block_result.irtype is IR.ELSE) {
                                const endif_result = inject_gas_funcs(block_bout, expr);
                                if_gas_count = max(endif_result.gas, if_gas_count);
                            }
                            gas_count += if_gas_count;
                        }
                        else {
                            // Emitted after the block opcode, i.e. at the top of the block body
                            inject_gas(bout, block_result.gas);
                        }
                        bout.write(block_bout);
                        break;
                    case BRANCH:
                    case BRANCH_IF:
                        wasmexpr(elm.code, elm.warg.get!uint);
                        break;
                    case BRANCH_TABLE:
                        const branch_idxs = elm.wargs.map!((a) => a.get!uint).array;
                        wasmexpr(elm.code, branch_idxs);
                        break;
                    case CALL, LOCAL, GLOBAL, CALL_INDIRECT:
                        wasmexpr(elm.code, elm.warg.get!uint);
                        break;
                    case MEMORY:
                        wasmexpr(elm.code, elm.wargs[0].get!uint, elm.wargs[1].get!uint);
                        break;
                    case MEMOP:
                        wasmexpr(elm.code);
                        break;
                    case CONST:
                        with (IR) {
                            switch (elm.code) {
                            case I32_CONST:
                                wasmexpr(elm.code, elm.warg.get!int);
                                break;
                            case I64_CONST:
                                wasmexpr(elm.code, elm.warg.get!long);
                                break;
                            case F32_CONST:
                                wasmexpr(elm.code, elm.warg.get!float);
                                break;
                            case F64_CONST:
                                wasmexpr(elm.code, elm.warg.get!double);
                                break;
                            default:
                                assert(0, format("Instruction %s is not a const", elm.code));
                            }
                        }
                        break;
                    case END:
                        wasmexpr(elm.code);
                        return GasResult(gas_count, elm.code);
                    case ILLEGAL:
                        assert(0, format("Illegal opcode %02X", elm.code));
                    case SYMBOL:
                        assert(0, "Symbol opcode and it does not have an equivalent opcode");
                    }
                }
            }
            return GasResult(gas_count, IR.END);
        }

        if (code_sec) {
            foreach (ref c; code_sec.sectypes) {
                scope expr_bout = new OutBuffer;
                auto expr_range = c[];
                expr_bout.reserve(c.expr.length * 5 / 4); // add 25%
                const gas_result = inject_gas_funcs(expr_bout, expr_range);
                scope code_bout = new OutBuffer;
                code_bout.reserve(expr_bout.offset + 2 * uint.sizeof);
                // Function-level charge goes before the rest of the body
                inject_gas(code_bout, gas_result.gas);
                code_bout.write(expr_bout);
                c.expr = code_bout.toBytes.idup;
            }
        }
    }

    /++
     Instruments the module for gas metering.

     Adds a mutable i32 gas-gauge global (initialised to 0), makes every existing
     function body call a new gas counter function (via `perform_gas_inject`),
     and adds the exported `set_gas_gauge` and `read_gas_gauge` functions.

     NOTE(review): the global and function indices used below are positions
     within the GLOBAL and FUNCTION sections only. Imported globals/functions,
     which precede them in the wasm index spaces, are not accounted for -
     confirm this is only applied to modules without such imports.
     +/
    void modify() {
        // Inject the gas-gauge global: (global (mut i32) (i32.const 0))
        GlobalType global_type;
        {
            scope out_expr = new OutBuffer;
            WasmExpr(out_expr)(IR.I32_CONST, 0)(IR.END);
            GlobalDesc global_desc = GlobalDesc(Types.I32, Mutable.VAR);
            immutable expr = out_expr.toBytes.idup;
            global_type = GlobalType(global_desc, expr);
        }
        const global_idx = inject(global_type);

        // The gas counter is the next entry appended to the FUNCTION section
        // (perform_gas_inject does not touch that section), so its index is the
        // section's current length.
        enum function_section = WasmWriter.fromSecType!TypeIndex;
        const func_sec = writer.mod[function_section];
        const gas_count_func_idx = cast(uint)((func_sec is null) ? 0 : func_sec.sectypes.length);

        void inject_gas_count(scope OutBuffer bout, const uint gas) {
            // A scope with zero cost needs no call
            if (gas > 0) {
                WasmExpr(bout)(IR.I32_CONST, gas)(IR.CALL, gas_count_func_idx);
            }
        }

        // Runs before the functions below are added, so they are not instrumented
        perform_gas_inject(&inject_gas_count);

        { // Gas down counter: traps unless $gas_gauge > 0 (signed), then $gas_gauge -= $gas
            FuncType func_type = FuncType(Types.FUNC, [Types.I32], null);
            const type_idx = inject(func_type);

            TypeIndex func_index = TypeIndex(type_idx);
            const func_idx = inject(func_index);
            assert(func_idx == gas_count_func_idx,
                    "Gas counter was not placed at the function index used by the injected calls");

            CodeType code_type;
            {
                scope out_expr = new OutBuffer;
                // dfmt off
                WasmExpr(out_expr)
                    // if ($gas_gauge > 0) {
                    (IR.GLOBAL_GET, global_idx)
                    (IR.I32_CONST, 0)
                    (IR.I32_GT_S)
                    (IR.IF, Types.EMPTY)
                    //   $gas_gauge -= $gas;  ($gas is parameter local 0)
                    (IR.GLOBAL_GET, global_idx)
                    (IR.LOCAL_GET, 0)
                    (IR.I32_SUB)
                    (IR.GLOBAL_SET, global_idx)
                    // } else {
                    (IR.ELSE)
                    //   trap;
                    (IR.UNREACHABLE)
                    // }
                    (IR.END)
                    (IR.END);
                // dfmt on
                immutable expr = out_expr.toBytes.idup;
                code_type = CodeType([CodeType.Local(1, Types.I32)], expr);
            }
            inject(code_type);
        }
        { // set_gas_gauge
            /+
             Inject the function type of set_gas_gauge: (param i32)
             +/
            FuncType func_type = FuncType(Types.FUNC, [Types.I32], null);
            const type_idx = inject(func_type);
            /+
             Inject the function header (type index) of set_gas_gauge
             +/
            TypeIndex func_index = TypeIndex(type_idx);
            const func_idx = inject(func_index);
            /+
             Inject the function body of set_gas_gauge
             +/
            CodeType code_type;
            {
                scope out_expr = new OutBuffer;
                // dfmt off
                WasmExpr(out_expr)
                    // void $set_gas_gauge(i32 $gas)
                    //   if ($gas_gauge != 0) {
                    (IR.GLOBAL_GET, global_idx)(IR.IF, Types.EMPTY)
                    //       exit;
                    (IR.UNREACHABLE)
                    //   } else {
                    (IR.ELSE)
                    //     $gas_gauge=$gas;
                    (IR.LOCAL_GET, 0)(IR.GLOBAL_SET, global_idx)
                    //   }
                    (IR.END)
                    //}
                    (IR.END);
                // dfmt on
                immutable expr = out_expr.toBytes.idup;
                code_type = CodeType([CodeType.Local(1, Types.I32)], expr);
            }
            inject(code_type);

            ExportType export_type = ExportType(set_gas_gauge, func_idx);
            inject(export_type);
        }
        { // read_gas_gauge: (result i32) returns $gas_gauge
            FuncType func_type = FuncType(Types.FUNC, null, [Types.I32]);
            const type_idx = inject(func_type);

            TypeIndex func_index = TypeIndex(type_idx);
            const func_idx = inject(func_index);

            CodeType code_type;
            {
                scope out_expr = new OutBuffer;
                // dfmt off
                WasmExpr(out_expr)
                    (IR.GLOBAL_GET, global_idx)
                    (IR.END);
                // dfmt on
                immutable expr = out_expr.toBytes.idup;
                code_type = CodeType(Types[].init, expr);
            }
            inject(code_type);

            ExportType export_type = ExportType(read_gas_gauge, func_idx);
            inject(export_type);
        }
    }
}
287 
version (none) unittest {
    import std.exception : assumeUnique;
    import std.file;
    import std.stdio;
    import tagion.wasm.WasmReader;
    import tagion.wasm.Wast;

    // Round-trips a sample module through WasmReader/WasmWriter, applies the
    // gas instrumentation, and prints the module as text before and after.
    //
    // Other sample modules that have been used here:
    //   ../tests/wasm/func_1.wasm        ../tests/wasm/imports_1.wasm
    //   ../tests/wasm/table_copy_2.wasm  ../tests/wasm/memory_2.wasm
    //   ../tests/wasm/start_4.wasm       ../tests/wasm/address_1.wasm
    //   ../tests/wasm/data_4.wasm        ../tests/web_gas_gauge.wasm
    string wasm_path = "../tests/wasm/global_1.wasm";

    /// Loads the whole file (or its first `limit` bytes) as immutable bytes
    @trusted static immutable(ubyte[]) load_bytes(string path, size_t limit = size_t.max) {
        import std.file : read_file = read;

        ubyte[] raw = cast(ubyte[]) read_file(path, limit);
        return assumeUnique(raw);
    }

    immutable wasm_bytes = load_bytes(wasm_path);
    auto reader = WasmReader(wasm_bytes);
    Wast(reader, stdout).serialize();

    auto writer = WasmWriter(reader);
    // The writer must reproduce the original binary before it is modified
    assert(reader.serialize == writer.serialize);

    auto gas_meter = WasmGas(writer);
    gas_meter.modify;
    {
        auto instrumented = WasmReader(writer.serialize);
        Wast(instrumented, stdout).serialize;
    }
}