diff --git a/src/lj_asm.c b/src/lj_asm.c index f38ceaef..34357e95 100644 --- a/src/lj_asm.c +++ b/src/lj_asm.c @@ -3212,8 +3212,14 @@ static void asm_tail_link(ASMState *as) if (as->T->link == TRACE_INTERP) { /* Setup fixed registers for exit to interpreter. */ + const BCIns *pc = snap_pc(as->T->snapmap[snap->mapofs + snap->nent]); + if (bc_op(*pc) == BC_JLOOP) { /* NYI: find a better way to do this. */ + BCIns *retpc = &as->J->trace[bc_d(*pc)]->startins; + if (bc_isret(bc_op(*retpc))) + pc = retpc; + } emit_loada(as, RID_DISPATCH, J2GG(as->J)->dispatch); - emit_loada(as, RID_PC, snap_pc(as->T->snapmap[snap->mapofs + snap->nent])); + emit_loada(as, RID_PC, pc); } else if (baseslot) { /* Save modified BASE for linking to trace with higher start frame. */ emit_setgl(as, RID_BASE, jit_base); diff --git a/src/lj_bc.h b/src/lj_bc.h index e1284916..74b11698 100644 --- a/src/lj_bc.h +++ b/src/lj_bc.h @@ -245,6 +245,11 @@ typedef enum { (BCM##ma|(BCM##mb<<3)|(BCM##mc<<7)|(MM_##mm<<11)), #define BCMODE_FF 0 +static LJ_AINLINE int bc_isret(BCOp op) +{ + return (op == BC_RETM || op == BC_RET || op == BC_RET0 || op == BC_RET1); +} + LJ_DATA const uint16_t lj_bc_mode[]; LJ_DATA const uint16_t lj_bc_ofs[]; diff --git a/src/lj_dispatch.c b/src/lj_dispatch.c index 83bb4fd8..f956aa1b 100644 --- a/src/lj_dispatch.c +++ b/src/lj_dispatch.c @@ -380,11 +380,8 @@ void LJ_FASTCALL lj_dispatch_ins(lua_State *L, const BCIns *pc) L->top = L->base + slots; /* Fix top again. */ } } - if ((g->hookmask & LUA_MASKRET)) { - BCOp op = bc_op(pc[-1]); - if (op == BC_RETM || op == BC_RET || op == BC_RET0 || op == BC_RET1) - callhook(L, LUA_HOOKRET, -1); - } + if ((g->hookmask & LUA_MASKRET) && bc_isret(bc_op(pc[-1]))) + callhook(L, LUA_HOOKRET, -1); } /* Initialize call. Ensure stack space and clear missing parameters. */ diff --git a/src/lj_jit.h b/src/lj_jit.h index 69156218..3201baf0 100644 --- a/src/lj_jit.h +++ b/src/lj_jit.h @@ -287,6 +287,9 @@ typedef struct jit_State { TraceNo parent; /* Parent of current side trace (0 for root traces). */ ExitNo exitno; /* Exit number in parent of current side trace. */ + BCIns *patchpc; /* PC for pending re-patch. */ + BCIns patchins; /* Instruction for pending re-patch. */ + TValue errinfo; /* Additional info element for trace errors. */ MCode *mcarea; /* Base of current mcode area. */ diff --git a/src/lj_record.c b/src/lj_record.c index da9c221c..62f5c066 100644 --- a/src/lj_record.c +++ b/src/lj_record.c @@ -522,6 +522,29 @@ static void rec_tailcall(jit_State *J, BCReg func, ptrdiff_t nargs) lj_trace_err(J, LJ_TRERR_LUNROLL); } +/* Check unroll limits for down-recursion. */ +static int check_downrec_unroll(jit_State *J, GCproto *pt) +{ + IRRef ptref; + for (ptref = J->chain[IR_KGC]; ptref; ptref = IR(ptref)->prev) + if (ir_kgc(IR(ptref)) == obj2gco(pt)) { + int count = 0; + IRRef ref; + for (ref = J->chain[IR_RETF]; ref; ref = IR(ref)->prev) + if (IR(ref)->op1 == ptref) + count++; + if (count) { + if (J->pc == J->startpc) { + if (count + J->tailcalled > J->param[JIT_P_recunroll]) + return 1; + } else { + lj_trace_err(J, LJ_TRERR_DOWNREC); + } + } + } + return 0; +} + /* Record return. */ static void rec_ret(jit_State *J, BCReg rbase, ptrdiff_t gotresults) { @@ -545,6 +568,15 @@ static void rec_ret(jit_State *J, BCReg rbase, ptrdiff_t gotresults) BCIns callins = *(frame_pc(frame)-1); ptrdiff_t nresults = bc_b(callins) ? (ptrdiff_t)bc_b(callins)-1 :gotresults; BCReg cbase = bc_a(callins); + GCproto *pt = funcproto(frame_func(frame - (cbase+1))); + if (J->pt && frame == J->L->base - 1) { + if (J->framedepth == 0 && check_downrec_unroll(J, pt)) { + J->maxslot = rbase + nresults; + rec_stop(J, J->curtrace); /* Down-recursion. */ + return; + } + lj_snap_add(J); + } for (i = 0; i < nresults; i++) /* Adjust results. */ J->base[i-1] = i < gotresults ? J->base[rbase+i] : TREF_NIL; J->maxslot = cbase+(BCReg)nresults; @@ -553,11 +585,10 @@ static void rec_ret(jit_State *J, BCReg rbase, ptrdiff_t gotresults) lua_assert(J->baseslot > cbase+1); J->baseslot -= cbase+1; J->base -= cbase+1; - } else if (J->parent == 0) { + } else if (J->parent == 0 && !bc_isret(bc_op(J->cur.startins))) { /* Return to lower frame would leave the loop in a root trace. */ lj_trace_err(J, LJ_TRERR_LLEAVE); } else { /* Return to lower frame. Guard for the target we return to. */ - GCproto *pt = funcproto(frame_func(frame - (cbase+1))); TRef trpt = lj_ir_kgc(J, obj2gco(pt), IRT_PROTO); TRef trpc = lj_ir_kptr(J, (void *)frame_pc(frame)); emitir(IRTG(IR_RETF, IRT_PTR), trpt, trpc); @@ -2285,6 +2316,12 @@ static const BCIns *rec_setup_root(jit_State *J) J->maxslot = ra; pc++; break; + case BC_RET: + case BC_RET0: + case BC_RET1: + /* No bytecode range check for down-recursive root traces. */ + J->maxslot = ra + bc_d(ins); + break; case BC_FUNCF: /* No bytecode range check for root traces started by a hot call. */ J->maxslot = J->pt->numparams; diff --git a/src/lj_trace.c b/src/lj_trace.c index 0b55f717..246dc03c 100644 --- a/src/lj_trace.c +++ b/src/lj_trace.c @@ -357,6 +357,8 @@ static void trace_start(jit_State *J) if ((J->pt->flags & PROTO_NO_JIT)) { /* JIT disabled for this proto? */ if (J->parent == 0) { /* Lazy bytecode patching to disable hotcount events. */ + lua_assert(bc_op(*J->pc) == BC_FORL || bc_op(*J->pc) == BC_ITERL || + bc_op(*J->pc) == BC_LOOP || bc_op(*J->pc) == BC_FUNCF); setbc_op(J->pc, (int)bc_op(*J->pc)+(int)BC_ILOOP-(int)BC_LOOP); J->pt->flags |= PROTO_HAS_ILOOP; } @@ -416,10 +418,16 @@ static void trace_stop(jit_State *J) /* Patch bytecode of starting instruction in root trace. */ setbc_op(pc, (int)op+(int)BC_JLOOP-(int)BC_LOOP); setbc_d(pc, J->curtrace); + addroot: /* Add to root trace chain in prototype. */ J->cur.nextroot = pt->trace; pt->trace = (TraceNo1)J->curtrace; break; + case BC_RET: + case BC_RET0: + case BC_RET1: + *pc = BCINS_AD(BC_JLOOP, J->cur.snap[0].nslots, J->curtrace); + goto addroot; case BC_JMP: /* Patch exit branch in parent to side trace entry. */ lua_assert(J->parent != 0 && J->cur.root != 0); @@ -450,6 +458,21 @@ static void trace_stop(jit_State *J) ); } +/* Start a new root trace for down-recursion. */ +static int trace_downrec(jit_State *J) +{ + /* Restart recording at the return instruction. */ + lua_assert(J->pt != NULL); + lua_assert(bc_isret(bc_op(*J->pc))); + if (bc_op(*J->pc) == BC_RETM) + return 0; /* NYI: down-recursion with RETM. */ + J->parent = 0; + J->exitno = 0; + J->state = LJ_TRACE_RECORD; + trace_start(J); + return 1; +} + /* Abort tracing. */ static int trace_abort(jit_State *J) { @@ -463,7 +486,7 @@ static int trace_abort(jit_State *J) return 1; /* Retry ASM with new MCode area. */ } /* Penalize or blacklist starting bytecode instruction. */ - if (J->parent == 0) + if (J->parent == 0 && !bc_isret(bc_op(J->cur.startins))) penalty_pc(J, &gcref(J->cur.startpt)->pt, (BCIns *)J->startpc, e); if (J->curtrace) { /* Is there anything to abort? */ ptrdiff_t errobj = savestack(L, L->top-1); /* Stack may be resized. */ @@ -493,17 +516,29 @@ static int trace_abort(jit_State *J) J->curtrace = 0; } L->top--; /* Remove error object */ - if (e == LJ_TRERR_MCODEAL) + if (e == LJ_TRERR_DOWNREC) + return trace_downrec(J); + else if (e == LJ_TRERR_MCODEAL) lj_trace_flushall(L); return 0; } +/* Perform pending re-patch of a bytecode instruction. */ +static LJ_AINLINE void trace_pendpatch(jit_State *J, int force) +{ + if (LJ_UNLIKELY(J->patchpc) && (force || J->chain[IR_RETF])) { + *J->patchpc = J->patchins; + J->patchpc = NULL; + } +} + /* State machine for the trace compiler. Protected callback. */ static TValue *trace_state(lua_State *L, lua_CFunction dummy, void *ud) { jit_State *J = (jit_State *)ud; UNUSED(dummy); do { + retry: switch (J->state) { case LJ_TRACE_START: J->state = LJ_TRACE_RECORD; /* trace_start() may change state. */ @@ -512,6 +547,7 @@ static TValue *trace_state(lua_State *L, lua_CFunction dummy, void *ud) break; case LJ_TRACE_RECORD: + trace_pendpatch(J, 0); setvmstate(J2G(J), RECORD); lj_vmevent_send(L, RECORD, setintV(L->top++, J->curtrace); @@ -523,6 +559,7 @@ static TValue *trace_state(lua_State *L, lua_CFunction dummy, void *ud) break; case LJ_TRACE_END: + trace_pendpatch(J, 1); J->loopref = 0; if ((J->flags & JIT_F_OPT_LOOP) && J->cur.link == J->curtrace && J->framedepth + J->retdepth == 0) { @@ -551,8 +588,9 @@ static TValue *trace_state(lua_State *L, lua_CFunction dummy, void *ud) setintV(L->top++, (int32_t)LJ_TRERR_RECERR); /* fallthrough */ case LJ_TRACE_ERR: + trace_pendpatch(J, 1); if (trace_abort(J)) - break; /* Retry. */ + goto retry; setvmstate(J2G(J), INTERP); J->state = LJ_TRACE_IDLE; lj_dispatch_update(J2G(J)); @@ -627,6 +665,7 @@ int LJ_FASTCALL lj_trace_exit(jit_State *J, void *exptr) lua_State *L = J->L; ExitDataCP exd; int errcode; + const BCIns *pc; exd.J = J; exd.exptr = exptr; errcode = lj_vm_cpcall(L, NULL, &exd, trace_exit_cp); @@ -651,8 +690,21 @@ int LJ_FASTCALL lj_trace_exit(jit_State *J, void *exptr) } ); - trace_hotside(J, exd.pc); - setcframe_pc(cframe_raw(L->cframe), exd.pc); + pc = exd.pc; + trace_hotside(J, pc); + if (bc_op(*pc) == BC_JLOOP) { + BCIns *retpc = &J->trace[bc_d(*pc)]->startins; + if (bc_isret(bc_op(*retpc))) { + if (J->state == LJ_TRACE_RECORD) { + J->patchins = *pc; + J->patchpc = (BCIns *)pc; + *J->patchpc = *retpc; + } else { + pc = retpc; + } + } + } + setcframe_pc(cframe_raw(L->cframe), pc); return 0; } diff --git a/src/lj_traceerr.h b/src/lj_traceerr.h index db7668fe..7b0dd813 100644 --- a/src/lj_traceerr.h +++ b/src/lj_traceerr.h @@ -22,6 +22,7 @@ TREDEF(LUNROLL, "loop unroll limit reached") TREDEF(BADTYPE, "bad argument type") TREDEF(CJITOFF, "call to JIT-disabled function") TREDEF(CUNROLL, "call unroll limit reached") +TREDEF(DOWNREC, "down-recursion, restarting") TREDEF(NYIVF, "NYI: vararg function") TREDEF(NYICF, "NYI: C function %p") TREDEF(NYIFF, "NYI: FastFunc %s")