From c84b1760629a583458df681ec02897f780588c20 Mon Sep 17 00:00:00 2001 From: fsfod Date: Fri, 4 Dec 2015 01:35:45 +0000 Subject: [PATCH] Initial support for intrinsics on x86/x64 interpreter only --- src/Makefile | 2 +- src/lib_ffi.c | 3 + src/lj_arch.h | 6 + src/lj_asm.c | 305 ++++++++++++++++++++ src/lj_ccall.c | 6 +- src/lj_clib.c | 8 + src/lj_cparse.c | 140 ++++++++- src/lj_crecord.c | 13 +- src/lj_ctype.c | 17 ++ src/lj_ctype.h | 32 ++- src/lj_emit_x86.h | 314 ++++++++++++++++++++- src/lj_errmsg.h | 4 + src/lj_intrinsic.c | 526 ++++++++++++++++++++++++++++++++++ src/lj_intrinsic.h | 113 ++++++++ src/lj_jit.h | 3 + src/lj_target_x86.h | 27 ++ src/lj_trace.c | 3 + tests/debug.sh | 7 + tests/intrinsic_spec.lua | 350 +++++++++++++++++++++++ tests/jit_tester.lua | 399 ++++++++++++++++++++++++++ tests/runtests.bat | 15 + tests/runtests.lua | 124 ++++++++ tests/runtests.sh | 4 + tests/telescope.lua | 594 +++++++++++++++++++++++++++++++++++++++ tests/tracetracker.lua | 224 +++++++++++++++ tests/tsc | 304 ++++++++++++++++++++ 26 files changed, 3524 insertions(+), 19 deletions(-) create mode 100644 src/lj_intrinsic.c create mode 100644 src/lj_intrinsic.h create mode 100644 tests/debug.sh create mode 100644 tests/intrinsic_spec.lua create mode 100644 tests/jit_tester.lua create mode 100644 tests/runtests.bat create mode 100644 tests/runtests.lua create mode 100644 tests/runtests.sh create mode 100644 tests/telescope.lua create mode 100644 tests/tracetracker.lua create mode 100644 tests/tsc diff --git a/src/Makefile b/src/Makefile index ad80642b..5458b025 100644 --- a/src/Makefile +++ b/src/Makefile @@ -488,7 +488,7 @@ LJLIB_C= $(LJLIB_O:.o=.c) LJCORE_O= lj_gc.o lj_err.o lj_char.o lj_bc.o lj_obj.o lj_buf.o \ lj_str.o lj_tab.o lj_func.o lj_udata.o lj_meta.o lj_debug.o \ lj_state.o lj_dispatch.o lj_vmevent.o lj_vmmath.o lj_strscan.o \ - lj_strfmt.o lj_strfmt_num.o lj_api.o lj_profile.o \ + lj_strfmt.o lj_strfmt_num.o lj_api.o lj_profile.o lj_intrinsic.o\ lj_lex.o lj_parse.o 
lj_bcread.o lj_bcwrite.o lj_load.o \ lj_ir.o lj_opt_mem.o lj_opt_fold.o lj_opt_narrow.o \ lj_opt_dce.o lj_opt_loop.o lj_opt_split.o lj_opt_sink.o \ diff --git a/src/lib_ffi.c b/src/lib_ffi.c index b3da1f3e..cff4e0fd 100644 --- a/src/lib_ffi.c +++ b/src/lib_ffi.c @@ -33,6 +33,8 @@ #include "lj_ff.h" #include "lj_lib.h" +#include "lj_intrinsic.h" + /* -- C type checks ------------------------------------------------------- */ /* Check first argument for a C type and returns its ID. */ @@ -849,6 +851,7 @@ LUALIB_API int luaopen_ffi(lua_State *L) { CTState *cts = lj_ctype_init(L); settabV(L, L->top++, (cts->miscmap = lj_tab_new(L, 0, 1))); + lj_intrinsic_init(L); cts->finalizer = ffi_finalizer(L); LJ_LIB_REG(L, NULL, ffi_meta); /* NOBARRIER: basemt is a GC root. */ diff --git a/src/lj_arch.h b/src/lj_arch.h index 903d6c64..c04bf74b 100644 --- a/src/lj_arch.h +++ b/src/lj_arch.h @@ -553,6 +553,12 @@ #define LJ_64 1 #endif +#if defined(LJ_TARGET_X86ORX64) && LJ_HASJIT && LJ_HASFFI +#define LJ_HASINTRINSICS 1 +#else +#define LJ_HASINTRINSICS 0 +#endif + #ifndef LJ_TARGET_UNALIGNED #define LJ_TARGET_UNALIGNED 0 #endif diff --git a/src/lj_asm.c b/src/lj_asm.c index c4c5dfdd..ec5e7c7f 100644 --- a/src/lj_asm.c +++ b/src/lj_asm.c @@ -16,6 +16,9 @@ #include "lj_frame.h" #if LJ_HASFFI #include "lj_ctype.h" +#if LJ_HASINTRINSICS +#include "lj_intrinsic.h" +#endif #endif #include "lj_ir.h" #include "lj_jit.h" @@ -27,6 +30,7 @@ #include "lj_asm.h" #include "lj_dispatch.h" #include "lj_vm.h" +#include "lj_err.h" #include "lj_target.h" #ifdef LUA_USE_ASSERT @@ -138,6 +142,18 @@ static LJ_AINLINE void checkmclim(ASMState *as) #endif } +MCode *asm_mcode(ASMState *as, void *mc, MSize sz) +{ + lua_assert(sz != 0 && sz < 0xffff && mc != NULL); + as->mcp -= sz; +#ifdef LUA_USE_ASSERT + as->mcp_prev = as->mcp; +#endif + if (LJ_UNLIKELY(as->mcp < as->mclim)) asm_mclimit(as); + memcpy(as->mcp, mc, sz); + return as->mcp; +} + #ifdef RID_NUM_KREF #define ra_iskref(ref) ((ref) < RID_NUM_KREF) 
#define ra_krefreg(ref) ((Reg)(RID_MIN_KREF + (Reg)(ref))) @@ -2408,6 +2424,295 @@ void lj_asm_trace(jit_State *J, GCtrace *T) lj_mcode_sync(T->mcode, origtop); } +#if LJ_HASINTRINSICS + +static void lj_asm_setup_intrins(jit_State *J, ASMState *as) +{ + MCodeArea *mcarea = &J->mcarea_intrins; + memset(as, 0, sizeof(ASMState)); + as->J = J; + as->flags = J->flags; + as->freeset = ~0; + /* Switch to the intrinsic machine code area */ + J->curmcarea = mcarea; + + /* Reserve MCode memory. */ + as->mctop = lj_mcode_reserve(J, &as->mcbot); + as->mcp = mcarea->top; +#ifdef LUA_USE_ASSERT + as->mcp_prev = as->mcp; +#endif + as->mclim = mcarea->bot + MCLIM_REDZONE; + as->mcloop = NULL; + as->flagmcp = NULL; +} + +typedef struct IntrinBuildState { + uint8_t in[LJ_INTRINS_MAXREG], out[LJ_INTRINS_MAXREG]; + RegSet inset, outset, modregs; + uint32_t spadj, contexspill, contexofs; + uint8_t outcontext; +} IntrinBuildState; + +static void intrins_setup(CIntrinsic *intrins, IntrinBuildState *info) +{ + MSize offset = 0, i; + memcpy(info->in, intrins->in, LJ_INTRINS_MAXREG); + memcpy(info->out, intrins->out, LJ_INTRINS_MAXREG); + + info->contexofs = -1; + + for (i = 0; i < intrins->insz; i++) { + Reg r = reg_rid(info->in[i]); + + if (reg_isgpr(info->in[i])) { + if (r == RID_CONTEXT) { + /* Save the offset in the input context so we can load it last */ + info->contexofs = offset; + } + offset += sizeof(intptr_t); + } + + rset_set(info->inset, r); + } + + for (i = 0; i < intrins->outsz; i++) { + rset_set(info->outset, reg_rid(info->out[i])); + } + + /* TODO: dynamic output context register selection */ + info->outcontext = RID_OUTCONTEXT; + info->modregs |= info->outset|info->inset; +} + +static void intrins_loadregs(ASMState *as, CIntrinsic *intrins, IntrinBuildState *info) +{ + uint32_t gpr = 0, fpr = 0, i; + + as->freeset = ~info->inset; + rset_clear(as->freeset, RID_CONTEXT); + + /* If the output content is needed but not spilled don't use it as a scratch */ + if (intrins->outsz 
> 0 && !info->contexspill) { + rset_clear(as->freeset, info->outcontext); + } + + /* Finally load the input register conflicting with the input context */ + if (rset_test(info->inset, RID_CONTEXT) && info->contexofs != -1) { + emit_loadofsirt(as, IRT_INTP, RID_CONTEXT, RID_CONTEXT, info->contexofs); + } + + /* Move values out the context into there respective input registers */ + for (i = 0; i < intrins->insz; i++) { + uint32_t reg = info->in[i]; + Reg r = reg_rid(reg); + + if (reg_isgpr(reg)) { + if (r != RID_CONTEXT) + emit_loadofsirt(as, IRT_INTP, r, RID_CONTEXT, gpr * sizeof(intptr_t)); + gpr++; + } else { + emit_loadfpr(as, reg, RID_CONTEXT, + offsetof(RegContext, fpr) + (sizeof(double) * fpr)); + fpr++; + } + checkmclim(as); + + if (r != RID_CONTEXT) + rset_set(as->freeset, r); + } +} + +static void intrins_saveregs(ASMState *as, CIntrinsic *intrins, IntrinBuildState *info) +{ + MSize offset = 0, i; + Reg outcontext = info->outcontext; + + /* All registers start as free because were emitting backwards */ + as->freeset = RSET_INIT; + rset_clear(as->freeset, outcontext); + + /* Save output registers into the context */ + for (i = 0; i < intrins->outsz; i++) { + uint32_t reg = info->out[i]; + Reg r = reg_rid(reg); + + if (r != outcontext) { + /* Exclude this register from the scratch set since its now live */ + rset_clear(as->freeset, r); + + if (r < RID_MAX_GPR) { + emit_savegpr(as, reg, outcontext, offset); + } else { + emit_savefpr(as, reg, outcontext, offset); + } + } + checkmclim(as); + offset += sizeof(TValue); + } +} + +/* +** Stack spill slots and gpr slots in the context are always the size of a native pointer +** The output context register is always spilled to a fixed stack offset +** The output context by default is the Lua stack. Signed 32 bit out register +** are directly written to it without any cdata boxing. 
+** Vectors are always passed in as pointers including the output context where +** the vector cdata is precreated and written to the stack before the wrapper is called. +** Vector are assumed tobe always unaligned for now when emitting load/stores +*/ + +static void wrap_intrins(jit_State *J, CIntrinsic *intrins, IntrinWrapState *state) +{ + ASMState as_; + ASMState *as = &as_; + IntrinBuildState info; + AsmHeader *hdr; + MCode *asmofs = NULL, *origtop; + int spadj = 0; + + lj_asm_setup_intrins(J, as); + origtop = as->mctop; + + memset(&info, 0, sizeof(info)); + info.modregs = state->mod; + intrins_setup(intrins, &info); + + /* Used for picking scratch register when loading or saving boxed values */ + as->modset = info.modregs|RID2RSET(RID_CONTEXT); /* modset is a RegSet: add the context register as a set bit, not as a raw register id */ + + /* Check if we need to save the register used to hold the pointer of the output context */ + if (intrins->outsz > 0) { + info.contexspill = rset_test(as->modset, info.outcontext); + ra_modified(as, info.outcontext); + } + +restart: + if (info.contexspill || rset_test(info.outset, info.outcontext)) { + /* add some extra space for context spill and temp spill */ + spadj = sizeof(intptr_t)*2; + } + + emit_epilogue(as, spadj, info.modregs, intrins->outsz); + + /* If one of the output registers was the same as the outcontext we will + * have saved the output value to the stack earlier, now save it into context + */ + if (rset_test(info.outset, info.outcontext) && intrins->outsz > 0) { + MSize offset = 0, i; + /* Don't use the context register as a scratch register */ + rset_clear(as->freeset, info.outcontext); + rset_clear(as->freeset, RID_RET); + + for (i = 0; i < intrins->outsz; i++) { + if (reg_rid(intrins->out[i]) == info.outcontext) { + break; + } + offset += sizeof(TValue); + } + + emit_savegpr(as, reg_setrid(intrins->out[i], RID_RET), info.outcontext, offset); + emit_loadofsirt(as, IRT_INTP, RID_RET, RID_SP, TEMPSPILL); + } + + /* Save output registers into the context */ + intrins_saveregs(as, intrins, &info); + + /* Restore 
the context register if it was overwritten by the intrinsic or was + ** an input register + */ + if (rset_test(info.modregs, info.outcontext) && intrins->outsz > 0) { + emit_loadofsirt(as, IRT_INTP, info.outcontext, RID_SP, CONTEXTSPILL); + } + + /* Save the value of the output context register if it is listed as an output register */ + if (rset_test(info.outset, info.outcontext)) { + emit_storeofsirt(as, IRT_INTP, info.outcontext, RID_SP, TEMPSPILL); + } + + /* Append the user supplied machine code */ + asmofs = asm_mcode(as, state->target, state->targetsz); + + /* Move values out the context into their respective input registers */ + intrins_loadregs(as, intrins, &info); + + /* Save the output context pointer if it will be overwritten */ + if (rset_test(info.modregs, info.outcontext) && intrins->outsz > 0) { + emit_storeofsirt(as, IRT_INTP, info.outcontext, RID_SP, CONTEXTSPILL); + } + + emit_prologue(as, spadj, info.modregs); + + /* Check if we used any extra non scratch register needed for loading the + * pointer of boxed values + */ + if ((as->modset & ~RSET_SCRATCH) & ~info.modregs) { + info.modregs |= (as->modset & ~RSET_SCRATCH); + as->mcp = origtop; + /* Have to restart so we can emit the epilogue with the missing reg saves */ + goto restart; + } + + hdr = ((AsmHeader*)as->mcp)-1; + memset(hdr, 0, sizeof(AsmHeader)); + hdr->totalzs = (uint32_t)(origtop-as->mcp); + + lua_assert((asmofs-as->mcp) < 0xffff); + hdr->asmofs = (uint16_t)(asmofs-as->mcp); + hdr->asmsz = state->targetsz; + + as->mcp = (MCode*)hdr; + + lj_mcode_sync(as->mcp, origtop); + lj_mcode_commit(J, as->mcp); + + /* Switch back the current machine code area to the jit one*/ + J->curmcarea = &J->mcarea; + + /* Return a pointer to the start of the wrapper */ + state->wrapper = (hdr+1); +} + +static TValue *wrap_intrins_cp(lua_State *L, lua_CFunction dummy, void *ud) +{ + IntrinWrapState *state = (IntrinWrapState*)ud; + UNUSED(dummy); + wrap_intrins(L2J(L), state->intrins, state); + return NULL; +} 
+ +int lj_asm_intrins(lua_State *L, IntrinWrapState *state) +{ + int errcode = 0; + + while ((errcode = lj_vm_cpcall(L, NULL, state, wrap_intrins_cp)) != 0) { + jit_State *J = L2J(L); + + lj_mcode_abort(J); + J->curmcarea = &J->mcarea; + + if (errcode == LUA_ERRRUN){ + if (tvisnumber(L->top-1)) { /* Trace error? */ + TraceError trerr = (TraceError)numberVint(L->top-1); + if (trerr == LJ_TRERR_MCODELM) { + /* mcarea reallocation try again */ + L->top--; + continue; + } + return -(trerr+2); + } else { + return -1; + } + } else { + lj_err_throw(L, errcode); + } + } + + return 0; +} + +#endif + #undef IR #endif diff --git a/src/lj_ccall.c b/src/lj_ccall.c index 25e938cb..7b1f5d88 100644 --- a/src/lj_ccall.c +++ b/src/lj_ccall.c @@ -1141,6 +1141,8 @@ static int ccall_get_results(lua_State *L, CTState *cts, CType *ct, return lj_cconv_tv_ct(cts, ctr, 0, L->top-1, sp); } +int lj_intrinsic_call(CTState *cts, CType *ct); + /* Call C function. */ int lj_ccall_func(lua_State *L, GCcdata *cd) { @@ -1151,7 +1153,9 @@ int lj_ccall_func(lua_State *L, GCcdata *cd) sz = ct->size; ct = ctype_rawchild(cts, ct); } - if (ctype_isfunc(ct->info)) { + if (ctype_isintrinsic(ct->info)) { + return lj_intrinsic_call(cts, ct); + }else if (ctype_isfunc(ct->info)) { CCallState cc; int gcsteps, ret; cc.func = (void (*)(void))cdata_getptr(cdataptr(cd), sz); diff --git a/src/lj_clib.c b/src/lj_clib.c index a7df719a..8bfdef3a 100644 --- a/src/lj_clib.c +++ b/src/lj_clib.c @@ -17,6 +17,7 @@ #include "lj_cdata.h" #include "lj_clib.h" #include "lj_strfmt.h" +#include "lj_intrinsic.h" /* -- OS-specific functions ----------------------------------------------- */ @@ -355,6 +356,13 @@ TValue *lj_clib_index(lua_State *L, CLibrary *cl, GCstr *name) setnumV(tv, (lua_Number)(uint32_t)ct->size); else setintV(tv, (int32_t)ct->size); + } else if(ctype_isintrinsic(ct->info)) { +#if LJ_HASINTRINSICS + /* TODO: maybe move to ASM namespace only */ + setcdataV(L, tv, lj_intrinsic_createffi(cts, ct)); +#else + 
lj_err_callermsg(L, "Intrinsics disabled"); +#endif } else { const char *sym = clib_extsym(cts, ct, name); #if LJ_TARGET_WINDOWS diff --git a/src/lj_cparse.c b/src/lj_cparse.c index a5a15da0..44f808d9 100644 --- a/src/lj_cparse.c +++ b/src/lj_cparse.c @@ -17,6 +17,7 @@ #include "lj_char.h" #include "lj_strscan.h" #include "lj_strfmt.h" +#include "lj_intrinsic.h" /* ** Important note: this is NOT a validating C parser! This is a minimal @@ -877,9 +878,31 @@ static CTypeID cp_decl_intern(CPState *cp, CPDecl *decl) sib = ct->sib; /* Next line may reallocate the C type table. */ fid = lj_ctype_new(cp->cts, &fct); csize = CTSIZE_INVALID; - fct->info = cinfo = info + id; + fct->info = info; fct->size = size; fct->sib = sib; + + if (ctype_isintrinsic(info)) { +#if LJ_HASINTRINSICS + CTypeID1 cid; + /* Don't overwrite the any attached register lists */ + if (ctype_cid(info) == 0) { + fct->info = info + id; + } + + cid = lj_intrinsic_fromcdef(cp->L, fid, decl->redir, decl->bits); + if (cid == 0) + cp_err(cp, LJ_ERR_FFI_INVTYPE); + decl->redir = NULL; + fct = ctype_get(cp->cts, fid); + fct->info = (info&0xffff0000) + cid; +#else + decl->redir = NULL; +#endif + } else { + fct->info = info + id; + } + cinfo = fct->info; id = fid; } else if (ctype_isattrib(info)) { if (ctype_isxattrib(info, CTA_QUAL)) @@ -1179,6 +1202,117 @@ static void cp_decl_msvcattribute(CPState *cp, CPDecl *decl) cp_check(cp, ')'); } +#if LJ_HASINTRINSICS + +static void cp_decl_mcode(CPState *cp, CPDecl *decl) +{ + /* Check were declared after a function definition */ + if (decl->top == 0) { + cp_err(cp, LJ_ERR_FFI_INVTYPE); + } else { + CTInfo info = decl->stack[decl->top-1].info; + if (!ctype_isfunc(info) || (info & CTF_VARARG)) { + cp_err(cp, LJ_ERR_FFI_INVTYPE); + } + } + cp_next(cp); + cp_check(cp, '('); + + if (cp->tok != CTOK_STRING) + cp_err_token(cp, CTOK_STRING); + /* Save the opcode/mode string for later parsing and validation */ + decl->redir = cp->str; + + cp_next(cp); + cp_check(cp, ')'); + 
/* Mark the function as an intrinsic */ + decl->stack[decl->top-1].info |= CTF_INTRINS; +} + +static void cp_reglist(CPState *cp, CPDecl *decl) +{ + CTypeID lastid = 0, anchor = 0; + CType *ct; + int listid = 0; /* 1 = "mod" list (register set), 2 = "out" list (typed output registers) */ + if (decl->top == 0) cp_err(cp, LJ_ERR_FFI_INVTYPE); /* No declarator on the stack to attach the list to; avoids indexing stack[-1] below */ + cp_next(cp); + cp_check(cp, '('); + if (cp->tok != CTOK_IDENT) + cp_err_token(cp, CTOK_IDENT); + + if (strcmp(strdata(cp->str), "mod") == 0) { + listid = 1; + } else if (strcmp(strdata(cp->str), "out") == 0) { + listid = 2; + } else { + cp_errmsg(cp, CTOK_STRING, LJ_ERR_FFI_INVTYPE); + } + + cp_next(cp); + cp_check(cp, ','); + + if (listid == 1) { + uint32_t rset = 0; + do { + if (cp->tok != CTOK_IDENT) + cp_err_token(cp, CTOK_IDENT); + + int reg = lj_intrinsic_getreg(cp->cts, cp->str); + + if (reg == -1) { + /* TODO: register error */ + cp_errmsg(cp, 0, LJ_ERR_FFI_INVTYPE); + } + + rset |= (uint32_t)1 << reg_rid(reg); /* Unsigned shift: register id 31 would overflow a signed int */ + cp_next(cp); + } while (cp_opt(cp, ',')); + + lastid = lj_ctype_new(cp->cts, &ct); + ct->info = CTINFO(CT_ATTRIB, 0); + ct->size = rset; + + decl->stack[decl->top-1].size |= lastid << 16; + } else { + do { + CPDecl decl; + CTypeID ctypeid = 0, fieldid; + + cp_decl_spec(cp, &decl, CDF_REGISTER); + decl.mode = CPARSE_MODE_DIRECT|CPARSE_MODE_ABSTRACT; + cp_declarator(cp, &decl); + ctypeid = cp_decl_intern(cp, &decl); + ct = ctype_raw(cp->cts, ctypeid); + + if (ctype_isvoid(ct->info) || ctype_isstruct(ct->info) || ctype_isfunc(ct->info)) { + cp_errmsg(cp, 0, LJ_ERR_FFI_INVTYPE); + } else if (ctype_isrefarray(ct->info)) { + ctypeid = lj_ctype_intern(cp->cts, + CTINFO(CT_PTR, CTALIGN_PTR|ctype_cid(ct->info)), CTSIZE_PTR); + } + + /* Add new parameter. 
*/ + fieldid = lj_ctype_new(cp->cts, &ct); + /* Type must have a register name after it */ + if (!decl.name) cp_err_token(cp, CTOK_IDENT); + ctype_setname(ct, decl.name); + ct->info = CTINFO(CT_FIELD, ctypeid); + ct->size = 0; + + if (anchor) + ctype_get(cp->cts, lastid)->sib = fieldid; + else + anchor = fieldid; + lastid = fieldid; + } while (cp_opt(cp, ',')); + + decl->stack[decl->top-1].info = (decl->stack[decl->top-1].info & 0xffff0000) | anchor; + } + + cp_check(cp, ')'); +} +#endif + /* Parse declaration attributes (and common qualifiers). */ static void cp_decl_attributes(CPState *cp, CPDecl *decl) { @@ -1190,6 +1324,10 @@ static void cp_decl_attributes(CPState *cp, CPDecl *decl) case CTOK_EXTENSION: break; /* Ignore. */ case CTOK_ATTRIBUTE: cp_decl_gccattribute(cp, decl); continue; case CTOK_ASM: cp_decl_asm(cp, decl); continue; +#if LJ_HASINTRINSICS + case CTOK_REGLIST: cp_reglist(cp, decl); continue; + case CTOK_MCODE: cp_decl_mcode(cp, decl); continue; +#endif case CTOK_DECLSPEC: cp_decl_msvcattribute(cp, decl); continue; case CTOK_CCDECL: #if LJ_TARGET_X86 diff --git a/src/lj_crecord.c b/src/lj_crecord.c index d425686d..49a67f57 100644 --- a/src/lj_crecord.c +++ b/src/lj_crecord.c @@ -32,6 +32,7 @@ #include "lj_crecord.h" #include "lj_dispatch.h" #include "lj_strfmt.h" +#include "lj_intrinsic.h" /* Some local macros to save typing. Undef'd at the end. */ #define IR(ref) (&J->cur.ir[(ref)]) @@ -42,6 +43,14 @@ #define emitconv(a, dt, st, flags) \ emitir(IRT(IR_CONV, (dt)), (a), (st)|((dt) << 5)|(flags)) + +#define MKREGKIND_IT(name, it, ct) it, + +uint8_t regkind_it[16] = { + RKDEF_GPR(MKREGKIND_IT) + RKDEF_FPR(MKREGKIND_IT) +}; + /* -- C type checks ------------------------------------------------------- */ static GCcdata *argv2cdata(jit_State *J, TRef tr, cTValue *o) @@ -1202,7 +1211,9 @@ static int crec_call(jit_State *J, RecordFFData *rd, GCcdata *cd) tp = (LJ_64 && ct->size == 8) ? 
IRT_P64 : IRT_P32; ct = ctype_rawchild(cts, ct); } - if (ctype_isfunc(ct->info)) { + if (ctype_isintrinsic(ct->info)) { + lj_trace_err(J, LJ_TRERR_NYICALL); + }else if (ctype_isfunc(ct->info)) { TRef func = emitir(IRT(IR_FLOAD, tp), J->base[0], IRFL_CDATA_PTR); CType *ctr = ctype_rawchild(cts, ct); IRType t = crec_ct2irt(cts, ctr); diff --git a/src/lj_ctype.c b/src/lj_ctype.c index 0ea89c74..ed9f9fb7 100644 --- a/src/lj_ctype.c +++ b/src/lj_ctype.c @@ -15,6 +15,7 @@ #include "lj_ctype.h" #include "lj_ccallback.h" #include "lj_buf.h" +#include "lj_intrinsic.h" /* -- C type definitions -------------------------------------------------- */ @@ -93,6 +94,8 @@ _("asm", 0, CTOK_ASM) \ _("__asm", 0, CTOK_ASM) \ _("__asm__", 0, CTOK_ASM) \ + _("__mcode", 0, CTOK_MCODE) \ + _("__reglist", 0, CTOK_REGLIST) \ /* MSVC Attributes. */ \ _("__declspec", 0, CTOK_DECLSPEC) \ _("__cdecl", CTCC_CDECL, CTOK_CCDECL) \ @@ -142,6 +145,14 @@ CTKWDEF(CTKWNAMEDEF) #define CTTYPETAB_MIN 128 #endif +#define MKREGKIND_CT(name, it, ct) ct, + +/* Default ctypes for each register kinds */ +CTypeID1 regkind_ct[16] = { + RKDEF_GPR(MKREGKIND_CT) + RKDEF_FPR(MKREGKIND_CT) +}; + /* -- C type interning ---------------------------------------------------- */ #define ct_hashtype(info, size) (hashrot(info, size) & CTHASH_MASK) @@ -522,10 +533,14 @@ static void ctype_repr(CTRepr *ctr, CTypeID id) } break; case CT_FUNC: + if (ctype_isintrinsic(info)) + ctype_preplit(ctr, "Intrinsic "); ctr->needsp = 1; if (ptrto) { ptrto = 0; ctype_prepc(ctr, '('); ctype_appc(ctr, ')'); } ctype_appc(ctr, '('); ctype_appc(ctr, ')'); + if (ctype_isintrinsic(info)) + return; break; default: lua_assert(0); @@ -618,6 +633,7 @@ CTState *lj_ctype_init(lua_State *L) if (!ctype_isenum(info)) ctype_addtype(cts, ct, id); } } + setmref(G(L)->ctype_state, cts); return cts; } @@ -630,6 +646,7 @@ void lj_ctype_freestate(global_State *g) lj_ccallback_mcode_free(cts); lj_mem_freevec(g, cts->tab, cts->sizetab, CType); lj_mem_freevec(g, 
cts->cb.cbid, cts->cb.sizeid, CTypeID1); + lj_mem_freevec(g, cts->intr.tab, cts->intr.sizetab, CIntrinsic); lj_mem_freet(g, cts); } } diff --git a/src/lj_ctype.h b/src/lj_ctype.h index 0c220a88..c3b081dc 100644 --- a/src/lj_ctype.h +++ b/src/lj_ctype.h @@ -72,6 +72,7 @@ LJ_STATIC_ASSERT(((int)CT_STRUCT & (int)CT_ARRAY) == CT_STRUCT); #define CTF_VECTOR 0x08000000u /* Vector: ARRAY. */ #define CTF_COMPLEX 0x04000000u /* Complex: ARRAY. */ #define CTF_UNION 0x00800000u /* Union: STRUCT. */ +#define CTF_INTRINS 0x04000000u /* Intrinsic: FUNC. */ #define CTF_VARARG 0x00800000u /* Vararg: FUNC. */ #define CTF_SSEREGPARM 0x00400000u /* SSE register parameters: FUNC. */ @@ -170,6 +171,30 @@ typedef LJ_ALIGN(8) struct CCallback { MSize slot; /* Current callback slot. */ } CCallback; +typedef int (LJ_FASTCALL *IntrinsicWrapper)(void *incontext, void* outcontext); + +typedef struct CIntrinsic { + IntrinsicWrapper wrapped; + uint8_t in[8]; + union { + uint8_t out[8]; + struct { + uint8_t oregs[4]; + uint32_t opcode; + }; + }; + uint8_t insz; + uint8_t outsz; + uint16_t flags; + CTypeID1 id; +} CIntrinsic; + +typedef struct IntrinsicState { + CIntrinsic* tab; /* Intrinsic descriptor table. */ + MSize sizetab; /* Size of intrinsic table. */ + MSize top; /* Current top of Intrinsic table. */ +} IntrinsicState; + /* C type state. */ typedef struct CTState { CType *tab; /* C type table. */ @@ -179,6 +204,7 @@ typedef struct CTState { global_State *g; /* Global state. */ GCtab *finalizer; /* Map of cdata to finalizer. */ GCtab *miscmap; /* Map of -CTypeID to metatable and cb slot to func. */ + IntrinsicState intr; /* Intrinsic descriptor table. */ CCallback cb; /* Temporary callback state. */ CTypeID1 hash[CTHASH_SIZE]; /* Hash anchors for C type table. 
*/ } CTState; @@ -246,6 +272,9 @@ typedef struct CTState { (((info) & (CTMASK_NUM|CTATTRIB(CTMASK_ATTRIB))) == \ CTINFO(CT_ATTRIB, CTATTRIB(at))) +#define ctype_isintrinsic(info) \ + (((info) & (CTMASK_NUM|CTF_INTRINS)) == CTINFO(CT_FUNC, CTF_INTRINS)) + /* Target-dependent sizes and alignments. */ #if LJ_64 #define CTSIZE_PTR 8 @@ -344,7 +373,7 @@ CTTYDEF(CTTYIDDEF) CDSDEF(_) _(EXTENSION) _(ASM) _(ATTRIBUTE) \ _(DECLSPEC) _(CCDECL) _(PTRSZ) \ _(STRUCT) _(UNION) _(ENUM) \ - _(SIZEOF) _(ALIGNOF) + _(SIZEOF) _(MCODE) _(REGLIST) _(ALIGNOF) /* C token numbers. */ enum { @@ -387,6 +416,7 @@ static LJ_AINLINE CTState *ctype_cts(lua_State *L) #define LJ_CTYPE_SAVE(cts) CTState savects_ = *(cts) #define LJ_CTYPE_RESTORE(cts) \ ((cts)->top = savects_.top, \ + (cts)->intr.top = savects_.intr.top, \ memcpy((cts)->hash, savects_.hash, sizeof(savects_.hash))) /* Check C type ID for validity when assertions are enabled. */ diff --git a/src/lj_emit_x86.h b/src/lj_emit_x86.h index b3dc4ea5..85c2196e 100644 --- a/src/lj_emit_x86.h +++ b/src/lj_emit_x86.h @@ -422,6 +422,26 @@ static void emit_loadk64(ASMState *as, Reg r, IRIns *ir) } } +static void emit_push(ASMState *as, Reg r) +{ /* Emit "push r". MCode is written backwards, so the prefix byte is stored after the opcode. */ + if (r < 8) { + *--as->mcp = XI_PUSH + r; + } else { + *--as->mcp = XI_PUSH + (r & 7); /* Only the low 3 bits of the register go into the opcode byte */ + *--as->mcp = 0x41; /* REX.B prefix selects r8-r15 */ + } +} + +static void emit_pop(ASMState *as, Reg r) +{ /* Emit "pop r". MCode is written backwards, so the prefix byte is stored after the opcode. */ + if (r < 8) { + *--as->mcp = XI_POP + r; + } else { + *--as->mcp = XI_POP + (r & 7); /* Only the low 3 bits of the register go into the opcode byte */ + *--as->mcp = 0x41; /* REX.B prefix selects r8-r15 */ + } +} + /* -- Emit control-flow instructions -------------------------------------- */ /* Label for short jumps. */ @@ -498,14 +518,15 @@ static void emit_jmp(ASMState *as, MCode *target) } /* call target */ -static void emit_call_(ASMState *as, MCode *target) +static void emit_call_(ASMState *as, MCode *target, Reg temp) { MCode *p = as->mcp; #if LJ_64 if (target-p != (int32_t)(target-p)) { /* Assumes RID_RET is never an argument to calls and always clobbered. 
*/ - emit_rr(as, XO_GROUP5, XOg_CALL, RID_RET); - emit_loadu64(as, RID_RET, (uint64_t)target); + if (temp == RID_NONE) temp = RID_RET; + emit_rr(as, XO_GROUP5, XOg_CALL, temp); + emit_loadu64(as, temp, (uint64_t)target); return; } #endif @@ -514,7 +535,7 @@ static void emit_call_(ASMState *as, MCode *target) as->mcp = p - 5; } -#define emit_call(as, f) emit_call_(as, (MCode *)(void *)(f)) +#define emit_call(as, f) emit_call_(as, (MCode *)(void *)(f), RID_NONE) /* -- Emit generic operations --------------------------------------------- */ @@ -537,22 +558,32 @@ static void emit_movrr(ASMState *as, IRIns *ir, Reg dst, Reg src) emit_rr(as, XO_MOVAPS, dst, src); } +#define emit_loadofs(as, ir, r, base, ofs) \ + emit_loadofsirt(as, irt_type(ir->t), r, base, ofs) + /* Generic load of register with base and (small) offset address. */ -static void emit_loadofs(ASMState *as, IRIns *ir, Reg r, Reg base, int32_t ofs) +static void emit_loadofsirt(ASMState *as, IRType irt, Reg r, Reg base, int32_t ofs) { - if (r < RID_MAX_GPR) - emit_rmro(as, XO_MOV, REX_64IR(ir, r), base, ofs); - else - emit_rmro(as, irt_isnum(ir->t) ? XO_MOVSD : XO_MOVSS, r, base, ofs); + if (r < RID_MAX_GPR) { + emit_rmro(as, XO_MOV, r | ((LJ_64 && ((IRT_IS64 >> irt) & 1)) ? REX_64 : 0), + base, ofs); + } else { + emit_rmro(as, irt == IRT_NUM ? XO_MOVSD : XO_MOVSS, r, base, ofs); + } } +#define emit_storeofs(as, ir, r, base, ofs) \ + emit_storeofsirt(as, irt_type(ir->t), r, base, ofs) + /* Generic store of register with base and (small) offset address. */ -static void emit_storeofs(ASMState *as, IRIns *ir, Reg r, Reg base, int32_t ofs) +static void emit_storeofsirt(ASMState *as, IRType irt, Reg r, Reg base, int32_t ofs) { - if (r < RID_MAX_GPR) - emit_rmro(as, XO_MOVto, REX_64IR(ir, r), base, ofs); - else - emit_rmro(as, irt_isnum(ir->t) ? XO_MOVSDto : XO_MOVSSto, r, base, ofs); + if (r < RID_MAX_GPR) { + emit_rmro(as, XO_MOVto, r | ((LJ_64 && ((IRT_IS64 >> irt) & 1)) ? 
REX_64 : 0), + base, ofs); + } else { + emit_rmro(as, irt == IRT_NUM ? XO_MOVSDto : XO_MOVSSto, r, base, ofs); + } } /* Add offset to pointer. */ @@ -571,3 +602,258 @@ static void emit_addptr(ASMState *as, Reg r, int32_t ofs) /* Prefer rematerialization of BASE/L from global_State over spills. */ #define emit_canremat(ref) ((ref) <= REF_BASE) +#if LJ_HASINTRINSICS + +#if LJ_64 +#define NEEDSFP 0 +#else +#define NEEDSFP 1 +#endif + +#define SPILLSTART (2*sizeof(intptr_t)) +#define TEMPSPILL (1*sizeof(intptr_t)) +#define CONTEXTSPILL (0) + +static int lj_popcnt(uint32_t i) +{ + i = i - ((i >> 1) & 0x55555555); + i = (i & 0x33333333) + ((i >> 2) & 0x33333333); + return (((i + (i >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24; +} + +#define align16(n) ((n + 16) & ~(16 - 1)) + +static int32_t alignsp(int32_t spadj, RegSet savereg) { + int32_t gprsave = lj_popcnt(savereg & RSET_GPR) * sizeof(intptr_t); + + if (NEEDSFP && rset_test(savereg, RID_EBP)) { + gprsave -= sizeof(intptr_t); + } +/* TODO: use shadow space/red zone on x64 to skip setting stack frame */ + spadj += gprsave; + +#if LJ_64 +#if LJ_ABI_WIN + if (savereg & RSET_FPR) { + spadj += lj_popcnt(savereg & RSET_FPR) * 16; + /* Add some slack in case the starting fpr save offset needs rounding up */ + spadj += 8; + } +#endif + if (spadj == 0) + return 0; + + spadj = align16(spadj); + /* No ebp pushed so the stack starts aligned to 8 bytes */ + if (!NEEDSFP)spadj += 8; +#endif + + return spadj; +} + +static void emit_prologue(ASMState *as, int spadj, RegSet modregs) +{ + int32_t offset, i; + RegSet savereg = modregs & ~RSET_SCRATCH; + + /* Save volatile registers after requested stack space */ + offset = spadj; + + /* save non scratch registers */ + for (i = RID_MIN_GPR; i < RID_MAX_GPR; i++) { + if (rset_test(savereg, i) && (i != RID_EBP || !NEEDSFP)) { + emit_rmro(as, XO_MOVto, i|REX_64, RID_SP, offset); + checkmclim(as); + offset += sizeof(intptr_t); + } + } + +#if LJ_ABI_WIN && LJ_64 + offset = align16(offset); + for 
(i = RID_MIN_FPR; i < RID_MAX_FPR; i++) { + if (rset_test(savereg, i)) { + emit_rmro(as, XO_MOVAPSto, i|REX_64, RID_SP, offset); + checkmclim(as); + offset += 16; + } + } +#endif + spadj = alignsp(spadj, savereg); + + if (spadj) { + emit_spsub(as, spadj); + } + + if (NEEDSFP) { + emit_rr(as, XO_MOV, RID_EBP|REX_64, RID_ESP); + emit_push(as, RID_EBP); + } +} + +static void emit_epilogue(ASMState *as, int spadj, RegSet modregs, int32_t ret) +{ + int32_t offset, i; + RegSet savereg = modregs & ~RSET_SCRATCH; + checkmclim(as); + + *--as->mcp = XI_RET; + if (NEEDSFP) + emit_pop(as, RID_EBP); + /* Save volatile registers after requested stack space */ + offset = spadj; + + spadj = alignsp(spadj, savereg); + + if (spadj != 0) { + emit_spsub(as, -spadj); + } + + as->mcp -= 4; + *(int32_t *)as->mcp = ret; + *--as->mcp = XI_MOVri + RID_RET; + + if (savereg == RSET_EMPTY) { + return; + } + + /* Restore non scratch registers */ + for (i = RID_MIN_GPR; i < RID_MAX_GPR; i++) { + if (rset_test(savereg, i) && (i != RID_EBP || !NEEDSFP)) { + emit_rmro(as, XO_MOV, i|REX_64, RID_SP, offset); + checkmclim(as); + offset += sizeof(intptr_t); + } + } + +#if LJ_ABI_WIN && LJ_64 + offset = align16(offset); + for (i = RID_MIN_FPR; i < RID_MAX_FPR; i++) { + if (rset_test(savereg, i)) { + emit_rmro(as, XO_MOVAPS, i|REX_64, RID_SP, offset); + checkmclim(as); + offset += 16; + } + } +#endif +} + +/* Trys to pick free register from the scratch or modified set first + * before resorting to register that will need tobe saved. 
+ */ +static Reg intrinsic_scratch(ASMState *as, RegSet allow) +{ + RegSet pick = (as->freeset & allow) & (as->modset|RSET_SCRATCH); + Reg r; + + if (!pick) { + pick = as->freeset & allow; + + if (pick == 0) { + /* No free registers */ + lj_trace_err(as->J, LJ_TRERR_BADRA); + } + + r = rset_pickbot(pick); + as->modset |= RID2RSET(r); /* record the fallback pick as modified */ + } else { + r = rset_pickbot(pick); + } + + /* start from the bottom where most of the non spilled registers are */ + return r; +} + +static void emit_savegpr(ASMState *as, Reg reg, Reg base, int ofs) /* Emits code that stores the GPR described by reg (id and kind) at base+ofs. */ +{ + Reg temp, r = reg_rid(reg); + uint32_t kind = reg_kind(reg); + lua_assert(r < RID_NUM_GPR); + + if (kind == REGKIND_GPRI32) { +#if LJ_DUALNUM + emit_i32(as, LJ_TISNUM); + emit_rmro(as, XO_MOVmi, 0, base, ofs+4); + emit_rmro(as, XO_MOVto, r, base, ofs); +#else + temp = intrinsic_scratch(as, RSET_FPR); /* the integer is converted with CVTSI2SD and stored as a double */ + emit_rmro(as, XO_MOVSDto, temp, base, ofs); + emit_mrm(as, XO_CVTSI2SD, temp, r); +#endif + return; + } + + if (kind == REGKIND_GPR64) { + r |= REX_64; + } + + temp = intrinsic_scratch(as, RSET_GPR); + /* Save the register into a cdata whose pointer is inside a TValue on the Lua stack */ + emit_rmro(as, XO_MOVto, r, temp, sizeof(GCcdata)); + emit_rmro(as, XO_MOV, temp, base, ofs); +} + +static void emit_loadfpr(ASMState *as, uint32_t reg, Reg base, int ofs) /* Emits code that loads the FPR described by reg from base+ofs; vectors are loaded through a pointer stored at base+ofs. */ +{ + x86Op op = XO_MOVSD; + Reg r = reg_rid(reg)-RID_MIN_FPR; + uint32_t kind = reg_kind(reg); + lua_assert(r < RID_NUM_FPR); + + switch (kind) { + case REGKIND_FPR64: + op = XO_MOVSD; + break; + case REGKIND_FPR32: + op = XO_MOVSS; + break; + case REGKIND_V128: + op = XO_MOVUPS; + break; + } + + if (!rk_isvec(kind)) { + emit_rmro(as, op, r, base, ofs); + } else { + Reg temp = intrinsic_scratch(as, RSET_GPR); + emit_rmro(as, op, r, temp, 0); + + /* Load a pointer to the vector out of the input context */ + emit_rmro(as, XO_MOV, temp|REX_64, base, ofs); + } +} + +static void emit_savefpr(ASMState *as, Reg reg, Reg base, int ofs) /* Emits code that stores the FPR described by reg at base+ofs; vectors are stored into a cdata. */ +{ + x86Op op = XO_MOVSDto; /* initialised like emit_loadfpr so op is defined for kinds the switch does not handle */ + Reg r = reg_rid(reg)-RID_MIN_FPR; + 
uint32_t kind = reg_kind(reg) & 3; + lua_assert(r < RID_NUM_FPR); + + switch (kind) { + case REGKIND_FPR64: + op = XO_MOVSDto; + break; + case REGKIND_FPR32: + op = XO_MOVSDto; /* stored as a double: the value is converted with CVTSS2SD below */ + break; + case REGKIND_V128: + op = XO_MOVUPSto; + break; + } + + if (!rk_isvec(kind)) { + emit_rmro(as, op, r, base, ofs); + if (kind == REGKIND_FPR32) { + emit_mrm(as, XO_CVTSS2SD, r, r); + } + } else { + Reg temp = intrinsic_scratch(as, RSET_GPR); + + /* Save the register into a cdata whose pointer is inside a TValue on the Lua stack */ + emit_rmro(as, op, r, temp, sizeof(GCcdata)); + emit_rmro(as, XO_MOV, temp, base, ofs); + } +} + +#endif + diff --git a/src/lj_errmsg.h b/src/lj_errmsg.h index 060a9f89..5881870e 100644 --- a/src/lj_errmsg.h +++ b/src/lj_errmsg.h @@ -179,8 +179,12 @@ ERRDEF(FFI_CBACKOV, "no support for callbacks on this OS") #else ERRDEF(FFI_CBACKOV, "too many callbacks") #endif +ERRDEF(FFI_BADREG, "bad register(%s) " LUA_QS " found in list " LUA_QS) +ERRDEF(FFI_REGOV, "register count for list " LUA_QS " exceeds the max register limit of %d") +ERRDEF(FFI_BADOPSTR, "bad opcode string " LUA_QS " %s") ERRDEF(FFI_NYIPACKBIT, "NYI: packed bit fields") ERRDEF(FFI_NYICALL, "NYI: cannot call this C function (yet)") +ERRDEF(FFI_INTRWRAP, "Failed to create interpreter wrapper for intrinsic(%s)") #endif #undef ERRDEF diff --git a/src/lj_intrinsic.c b/src/lj_intrinsic.c new file mode 100644 index 00000000..17e5869c --- /dev/null +++ b/src/lj_intrinsic.c @@ -0,0 +1,526 @@ +/* +** FFI Intrinsic system. 
+*/ + +#define LUA_CORE +#include "lj_arch.h" +#include "lj_tab.h" +#include "lj_err.h" +#include "lj_intrinsic.h" + +#if LJ_HASINTRINSICS + +#include "lj_lib.h" +#include "lj_err.h" +#include "lj_str.h" +#include "lj_char.h" +#include "lj_cdata.h" +#include "lj_cconv.h" +#include "lj_jit.h" +#include "lj_trace.h" +#include "lj_dispatch.h" +#include "lj_target.h" + +typedef enum RegFlags { + REGFLAG_64BIT = REGKIND_GPR64 << 6, /* 64 bit override */ + REGFLAG_BLACKLIST = 1 << 17, /* register is rejected by process_reglist */ +}RegFlags; + +typedef struct RegEntry { + const char* name; + unsigned int slot; /* Register id combined with RegFlags bits */ +}RegEntry; + +#define RIDENUM(name) RID_##name, + +#define MKREG(name) {#name, RID_##name}, +#define MKREGGPR(reg, name) {#name, RID_##reg}, +#define MKREG_GPR64(reg, name) {#name, REGFLAG_64BIT|RID_##reg}, + +#if LJ_64 +#define GPRDEF2(_) \ + _(EAX, eax) _(ECX, ecx) _(EDX, edx) _(EBX, ebx) _(ESP|REGFLAG_BLACKLIST, esp) \ + _(EBP, ebp) _(ESI, esi) _(EDI, edi) _(R8D, r8d) _(R9D, r9d) _(R10D, r10d) \ + _(R11D, r11d) _(R12D, r12d) _(R13D, r13d) _(R14D, r14d) _(R15D, r15d) + +#define GPRDEF_R64(_) \ + _(EAX, rax) _(ECX, rcx) _(EDX, rdx) _(EBX, rbx) _(ESP|REGFLAG_BLACKLIST, rsp) _(EBP, rbp) _(ESI, rsi) _(EDI, rdi) +#else +#define GPRDEF2(_) \ + _(EAX, eax) _(ECX, ecx) _(EDX, edx) _(EBX, ebx) _(ESP|REGFLAG_BLACKLIST, esp) _(EBP, ebp) _(ESI, esi) _(EDI, edi) +#endif + +RegEntry reglut[] = { /* GPR name table; lj_intrinsic_init copies it into the miscmap table */ + GPRDEF2(MKREGGPR) +#if LJ_64 + GPRDEF_R64(MKREG_GPR64) +#endif +}; + +static CTypeID register_intrinsic(lua_State *L, CIntrinsic* src, CType *func) /* Copies src into the next free slot of cts->intr.tab, links that slot to func through func->size and returns the type id of func. */ +{ + CTState *cts = ctype_cts(L); + CType *ct; + CTypeID id; + CIntrinsic *intrins; + lua_assert(ctype_isintrinsic(func->info)); + + if (cts->intr.top+1 >= LJ_INTRINS_MAXID) lj_err_msg(cts->L, LJ_ERR_TABOV); + + if ((cts->intr.top+1) > cts->intr.sizetab) { + lj_mem_growvec(cts->L, cts->intr.tab, cts->intr.sizetab, LJ_INTRINS_MAXID, CIntrinsic); + } + + id = ctype_typeid(cts, func); + ct = func; + + /* Upper bits of size are used for modified link; the low 16 bits hold the intrinsic table index */ + ct->size = 
(ct->size & 0xffff0000) | cts->intr.top; + intrins = &cts->intr.tab[cts->intr.top++]; + memcpy(intrins, src, sizeof(CIntrinsic)); + intrins->id = id; + + return id; +} + +static int parse_fprreg(const char *name, uint32_t len) /* Parses "xmm<N>" with an optional suffix: f selects REGKIND_FPR32, v selects REGKIND_V128, none REGKIND_FPR64. Returns the packed register id and kind, or -1 if the name is invalid. */ +{ + uint32_t rid = 0, kind = REGKIND_FPR64; + uint32_t pos = 3; + + if (len < 3 || name[0] != 'x' || + name[1] != 'm' || name[2] != 'm') + return -1; + + if (lj_char_isdigit((uint8_t)name[3])) { + rid = name[3] - '0'; + pos = 4; + + if (LJ_64 && lj_char_isdigit((uint8_t)name[4])) { /* a second digit is only accepted when LJ_64 is set */ + rid = rid*10; + rid += name[4] - '0'; + pos++; + } + + if (rid >= RID_NUM_FPR) { + return -1; + } + rid += RID_MIN_FPR; + } else { + return -1; + } + + if (pos < len) { + if (name[pos] == 'f') { + kind = REGKIND_FPR32; + pos++; + } else if (name[pos] == 'v') { + kind = REGKIND_V128; + pos++; + } else { + kind = REGKIND_FPR64; + } + } + + if (pos < len) { /* trailing characters after the number and suffix make the name invalid */ + return -1; + } + + return reg_make(rid, kind); +} + +int lj_intrinsic_getreg(CTState *cts, GCstr *name) { /* Names starting with 'x' are parsed by parse_fprreg, all others are looked up in cts->miscmap. Returns -1 for an unknown name. */ + + if (strdata(name)[0] == 'x') { + return parse_fprreg(strdata(name), name->len); + } else { + cTValue *reginfotv = lj_tab_getstr(cts->miscmap, name); + + if (reginfotv && !tvisnil(reginfotv)) { + return (uint32_t)(uintptr_t)lightudV(reginfotv); + } + } + + return -1; +} + +static CType *setarg_casttype(CTState *cts, CType *ctarg, CType *ct) { /* Stores the id of the type to cast to in the upper 16 bits of ctarg->size; for a vector type this is a pointer to the vector type. Returns ctarg, possibly re-fetched. */ + CTypeID id; + + if (ctype_isvector(ct->info)) { + CTypeID argid = ctype_typeid(cts, ctarg); + id = lj_ctype_intern(cts, CTINFO(CT_PTR, CTALIGN_PTR|ctype_typeid(cts, ct)), + CTSIZE_PTR); + ctarg = ctype_get(cts, argid); /* re-fetched by id after interning, presumably because the ctype table may have been reallocated; confirm */ + } else { + id = ctype_typeid(cts, ct); + } + + ctarg->size |= id << 16; + return ctarg; +} + +enum IntrinsRegSet { + REGSET_IN, + REGSET_OUT, + REGSET_MOD, +}; + +/* Walks through either a Lua table(array) of register names or ctype linked list +** of typed parameters whose name will be the register for that specific parameter. 
+** The register names are converted into a register id\kind which are packed +** together into a uint8_t that is saved into one of the register lists of the +** CIntrinsic passed in. +*/ +static RegSet process_reglist(lua_State *L, CIntrinsic *intrins, int regsetid, + CTypeID liststart) +{ + CTState *cts = ctype_cts(L); + uint32_t i, count = 0; + RegSet rset = 0; + const char *listname; + uint8_t *regout = NULL; + CTypeID sib = liststart; + + if (regsetid == REGSET_IN) { + listname = "in"; + regout = intrins->in; + } else if(regsetid == REGSET_OUT) { + listname = "out"; + regout = intrins->out; + } else { + listname = "mod"; + } + + for (i = 1; sib; i++) { + CType *ctarg = ctype_get(cts, sib); + GCstr *str = strref(ctarg->name); + Reg r = 0; + int32_t reg = -1; + sib = ctarg->sib; + + if (i > LJ_INTRINS_MAXREG && regsetid != REGSET_MOD) { + lj_err_callerv(L, LJ_ERR_FFI_REGOV, listname, LJ_INTRINS_MAXREG); + } + + reg = lj_intrinsic_getreg(cts, str); + + if (reg < 0) { + /* Unrecognized register name */ + lj_err_callerv(L, LJ_ERR_FFI_BADREG, "invalid name", strdata(str), listname); + } + + /* Pack the register info into the ctype argument */ + ctarg->size = reg & 0xff; + setarg_casttype(cts, ctarg, ctype_rawchild(cts, ctarg)); + + r = reg_rid(reg); + + /* Check for duplicate registers in the list */ + if (rset_test(rset, r)) { + lj_err_callerv(L, LJ_ERR_FFI_BADREG, "duplicate", strdata(str), listname); + } + rset_set(rset, r); + + if (regsetid == REGSET_OUT && reg_isgpr(reg)) { + CType *ct = ctype_rawchild(cts, ctarg); + + if (ctype_ispointer(ct->info) && !LJ_64) { + reg = reg_make(r, REGKIND_GPR32CD); + } else if (ctype_isnum(ct->info) && (ct->info & CTF_UNSIGNED) && ct->size == 4) { + reg = reg_make(r, REGKIND_GPR32CD); + } + } + + intrins->flags |= reg & 0xff00; + + if (reg & REGFLAG_BLACKLIST) { + lj_err_callerv(L, LJ_ERR_FFI_BADREG, "blacklisted", strdata(str), listname); + } + + if (regout) { + regout[count++] = (uint8_t)reg; + } + } + + if (regsetid == 
REGSET_IN) { + intrins->insz = (uint8_t)count; + } else if (regsetid == REGSET_OUT) { + intrins->outsz = (uint8_t)count; + } + + return rset; +} + +static void setopcode(lua_State *L, CIntrinsic *intrins, uint32_t opcode) +{ + int len; + + if (opcode == 0) { + lj_err_callermsg(L, "bad opcode literal"); + } + +#if LJ_TARGET_X86ORX64 + if (opcode <= 0xff) { + len = 1; + } else if (opcode <= 0xffff) { + len = 2; + } else if (opcode <= 0xffffff) { + len = 3; + } else { + len = 4; + } + + opcode = lj_bswap(opcode); + if (len < 4) { + opcode |= (uint8_t)(int8_t)-(len+1); + } else { + lj_err_callermsg(L, "bad opcode literal"); + } +#endif + + intrins->opcode = opcode; +} + +static int parse_opstr(lua_State *L, GCstr *opstr) +{ + const char *op = strdata(opstr); + uint32_t opcode = 0; + uint32_t i; + + /* Find the end of the opcode number */ + for (i = 0; i < opstr->len && lj_char_isxdigit((uint8_t)op[i]); i++) { + } + + if (i == 0 || i > 8) { + /* invalid or no hex number */ + lj_err_callerv(L, LJ_ERR_FFI_BADOPSTR, op, "invalid opcode number"); + } + + /* Scan hex digits. 
*/ + for (; i; i--, op++) { + uint32_t d = *op; if (d > '9') d += 9; /* for letters, adding 9 makes the low nibble 10..15 */ + opcode = (opcode << 4) + (d & 15); + } + + return opcode; +} + +extern int lj_asm_intrins(lua_State *L, IntrinWrapState *state); + +static IntrinsicWrapper lj_intrinsic_buildwrap(lua_State *L, CIntrinsic *intrins, + void* target, MSize targetsz, RegSet mod) /* Calls lj_asm_intrins to build the wrapper; raises LJ_ERR_FFI_INTRWRAP with a reason string if it fails. */ +{ + IntrinWrapState state = { 0 }; + state.intrins = intrins; + state.target = target; + state.targetsz = targetsz; + state.mod = mod; + state.wrapper = 0; + + int err = lj_asm_intrins(L, &state); /* NOTE(review): declaration after statements needs a C99 compiler */ + + if (err != 0) { + const char* reason = "unknown error"; + + if (err == -(LJ_TRERR_BADRA+2)) { + reason = "too many live registers"; + } else if (err == -(LJ_TRERR_MCODEOV+2)) { + reason = "code too large for mcode area"; + } else if(err == -1 && tvisstr(L->top-1)) { + reason = strVdata(L->top-1); /* use the error string left on top of the Lua stack */ + } + + lj_err_callerv(L, LJ_ERR_FFI_INTRWRAP, reason); + } + + return (IntrinsicWrapper)state.wrapper; +} + +GCcdata *lj_intrinsic_createffi(CTState *cts, CType *func) /* Builds the wrapper for the intrinsic linked to func and returns a new cdata holding the wrapper pointer. */ +{ + GCcdata *cd; + CIntrinsic *intrins = lj_intrinsic_get(cts, func->size); + CTypeID id = ctype_typeid(cts, func); + RegSet mod = intrin_getmodrset(cts, intrins); + uint32_t op = intrins->opcode; + void* mcode = ((char*)&op) + (4-intrin_oplen(intrins)); /* address of the opcode bytes inside the local copy op */ + + intrins->wrapped = lj_intrinsic_buildwrap(cts->L, intrins, mcode, + intrin_oplen(intrins), mod); + + cd = lj_cdata_new(cts, id, CTSIZE_PTR); + *(void **)cdataptr(cd) = intrins->wrapped; + return cd; +} + +int lj_intrinsic_fromcdef(lua_State *L, CTypeID fid, GCstr *opstr, uint32_t imm) /* Creates and registers the intrinsic for function type fid from its opcode string. NOTE(review): imm is not used in this function. */ +{ + CTState *cts = ctype_cts(L); + CType *func = ctype_get(cts, fid); + CTypeID sib = func->sib, retid = ctype_cid(func->info); + uint32_t opcode; + CIntrinsic _intrins; + CIntrinsic* intrins = &_intrins; + memset(intrins, 0, sizeof(CIntrinsic)); + + opcode = parse_opstr(L, opstr); + + if (!opcode) { + return 0; + } + + if (sib) { + process_reglist(L, intrins, REGSET_IN, sib); + } + + + if (retid != CTID_VOID) { + CType *ct = ctype_get(cts, retid); + 
+ /* Check if the intrinsic had __reglist declared on it */ + if (ctype_isfield(ct->info)) { + process_reglist(L, intrins, REGSET_OUT, retid); + sib = retid; + } + } else { + sib = retid; + } + + setopcode(L, intrins, opcode); + register_intrinsic(L, intrins, ctype_get(cts, fid)); /* the function ctype is fetched again by id rather than reusing func */ + + lua_assert(sib > 0 && sib < cts->top); + return sib; +} + +/* Pre-create cdata for any output values that need boxing; the wrapper will directly + * save the values into the cdata. Returns the first result slot, or NULL when id is CTID_VOID. + */ +static void *setup_results(lua_State *L, CIntrinsic *intrins, CTypeID id) +{ + MSize i; + CTState *cts = ctype_cts(L); + CTypeID sib = 0; + void *outcontext = L->top; /* results are placed on the Lua stack starting at the current top */ + + if (id == CTID_VOID) + return NULL; + + sib = id; + for (i = 0; i < intrins->outsz; i++) { + CType *ret = ctype_get(cts, sib); + CTypeID retid = ctype_cid(ret->info); + CType *ct = ctype_raw(cts, retid); + CTypeID rawid = ctype_typeid(cts, ct); + lua_assert(ctype_isfield(ret->info) && ctype_cid(ret->info)); + sib = ret->sib; + + /* Don't box what can be represented with a lua_number */ + if (rawid == CTID_INT32 || rawid == CTID_FLOAT || rawid == CTID_DOUBLE) + ct = NULL; + + if (ct) { + GCcdata *cd; + if (!(ct->info & CTF_VLA) && ctype_align(ct->info) <= CT_MEMALIGN) + cd = lj_cdata_new(cts, retid, ct->size); + else + cd = lj_cdata_newv(L, retid, ct->size, ctype_align(ct->info)); + + setcdataV(L, L->top++, cd); + } else { + L->top++; /* NOTE(review): the slot is reserved without being initialised here, presumably the wrapper fills it; confirm */ + } + } + + return outcontext; +} + +int lj_intrinsic_call(CTState *cts, CType *ct) /* Converts the Lua arguments into a RegContext, prepares the result slots and runs the wrapper; returns the wrapper's return value. */ +{ + lua_State *L = cts->L; + CIntrinsic *intrins = lj_intrinsic_get(cts, ct->size); + CTypeID fid, funcid = ctype_typeid(cts, ct); + TValue *o; + MSize ngpr = 0, nfpr = 0, narg; + void* outcontent = L->top; + uint32_t reg = 0; + RegContext context; + memset(&context, 0, sizeof(RegContext)); + + /* Skip initial attributes. 
*/ + fid = ct->sib; + while (fid) { + CType *ctf = ctype_get(cts, fid); + if (!ctype_isattrib(ctf->info)) break; + fid = ctf->sib; + } + + narg = (MSize)((L->top-L->base)-1); + + /* Check for wrong number of arguments passed in: the count must equal insz exactly. */ + if (narg < intrins->insz || narg > intrins->insz) { + lj_err_caller(L, LJ_ERR_FFI_NUMARG); + } + + /* Walk through all passed arguments. */ + for (o = L->base+1, narg = 0; narg < intrins->insz; o++, narg++) { + CType *ctf = ctype_get(cts, fid); + CType *d = ctype_get(cts, ctf->size >> 16); /* Use saved raw type we want to cast to */ + void *dp; + fid = ctf->sib; + lua_assert(ctype_isfield(ctf->info)); + + reg = ctf->size & 0xff; /* packed register id and kind stored by process_reglist */ + + /* nil only makes sense for gpr based ptr arguments */ + if (tvisnil(o) && (!reg_isgpr(reg) || !ctype_isptr(d->info))) { + lj_err_arg(L, narg+1, LJ_ERR_NOVAL); + } + + if (reg_isgpr(reg)) { /* GPR and FPR arguments fill separate arrays of the context in argument order */ + lua_assert((ctype_isnum(d->info) && d->size <= 8) || ctype_isptr(d->info)); + dp = &context.gpr[ngpr++]; + } else { + lua_assert(reg_isvec(reg) || (ctype_isnum(d->info) && (d->info & CTF_FP))); + dp = &context.fpr[nfpr++]; + } + + lj_cconv_ct_tv(cts, d, (uint8_t *)dp, o, CCF_ARG(narg+1)); + } + + /* Pass in the return type chain so the results are typed */ + outcontent = setup_results(L, intrins, ctype_cid(ctype_get(cts, funcid)->info)); + + /* Execute the intrinsic through the wrapper created on first lookup */ + return (*(IntrinsicWrapper*)cdataptr(cdataV(L->base)))(&context, outcontent); +} + +void lj_intrinsic_init(lua_State *L) /* Registers every reglut entry in the miscmap table as name -> light userdata holding the slot value. */ +{ + uint32_t i, count = (uint32_t)(sizeof(reglut)/sizeof(RegEntry)); + GCtab *t = ctype_cts(L)->miscmap; + + /* Build register name lookup table */ + for (i = 0; i < count; i++) { + TValue *slot = lj_tab_setstr(L, t, lj_str_newz(L, reglut[i].name)); + setlightudV(slot, (void*)(uintptr_t)reglut[i].slot); + } +} + +#else + +int lj_intrinsic_call(CTState *cts, CType *ct) /* Stub used when LJ_HASINTRINSICS is not set. */ +{ + UNUSED(cts); UNUSED(ct); + return 0; +} + +void lj_intrinsic_init(lua_State *L) /* Stub used when LJ_HASINTRINSICS is not set. */ +{ +} +#endif + + + + diff --git 
a/src/lj_intrinsic.h b/src/lj_intrinsic.h new file mode 100644 index 00000000..e53183b4 --- /dev/null +++ b/src/lj_intrinsic.h @@ -0,0 +1,113 @@ +/* +** FFI Intrinsic system. +*/ + +#ifndef _LJ_INTRINSIC_H +#define _LJ_INTRINSIC_H + +#include "lj_arch.h" +#include "lj_obj.h" +#include "lj_clib.h" +#include "lj_ctype.h" + +#if !defined(LJ_INTRINS_MAXREG) || LJ_INTRINS_MAXREG < 8 +#define LJ_INTRINS_MAXREG 8 +#endif + +typedef struct LJ_ALIGN(16) RegContext { + intptr_t gpr[LJ_INTRINS_MAXREG]; + double fpr[LJ_INTRINS_MAXREG]; +} RegContext; /* input register values handed to an intrinsic wrapper */ + +typedef enum INTRINSFLAGS { + INTRINSFLAG_MEMORYSIDE = 0x08, /* has memory side effects so needs an IR memory barrier */ + + /* Intrinsic should be emitted as a naked function that is called */ + INTRINSFLAG_CALLED = 0x20, + /* MODRM should always be set as indirect mode */ + INTRINSFLAG_INDIRECT = 0x40, + + INTRINSFLAG_CALLEDIND = INTRINSFLAG_CALLED | INTRINSFLAG_INDIRECT +} INTRINSFLAGS; + +typedef struct AsmHeader { + union{ + uintptr_t target; + struct { + uint16_t asmsz; + uint16_t asmofs; + }; + }; + uint32_t totalzs; +} AsmHeader; + +#define intrin_oplen(intrins) ((-(int8_t)(intrins)->opcode)-1) /* the low byte of opcode holds -(len+1) */ +#define intrin_getmodrset(cts, intrins) \ + ((ctype_get(cts, (intrins)->id)->size >> 16) ? 
\ + ctype_get(cts, ctype_get(cts, (intrins)->id)->size >> 16)->size : 0) + +#define RKDEF_FPR(_) \ + _(FPR64, IRT_NUM, CTID_DOUBLE) \ + _(FPR32, IRT_FLOAT, CTID_FLOAT) \ + _(V128, 0, 0) \ + _(FPR5, 0, 0) \ + _(FPR6, 0, 0) \ + _(FPR7, 0, 0) \ + +#define RKDEF_GPR(_) \ + _(GPRI32, IRT_INT, CTID_INT32) \ + _(GPR32CD, IRT_U32, CTID_UINT32) \ + _(GPR64, IRT_U64, CTID_UINT64) \ + _(GPR3, 0, 0) \ + _(GPR4, 0, 0) \ + _(GPR5, 0, 0) \ + _(GPR6, 0, 0) \ + _(GPR7, 0, 0) \ + +#define MKREGKIND(name, irt, ct) REGKIND_##name, + +typedef enum REGKINDGPR { + RKDEF_GPR(MKREGKIND) +} REGKINDGPR; + +typedef enum REGKINDFPR { + RKDEF_FPR(MKREGKIND) + REGKIND_VEC_START = REGKIND_V128, /* kinds from here on are treated as vectors by rk_isvec and reg_isvec */ +} REGKINDFPR; + +uint8_t regkind_it[16]; /* NOTE(review): defined in a header without extern or static; confirm a single definition is intended */ +CTypeID1 regkind_ct[16]; + +#define reg_rid(r) ((r)&63) /* low 6 bits: register id */ +#define reg_kind(r) (((r) >> 6) & 3) /* bits 6..7: register kind */ +#define reg_make(r, kind) ((r) | ((kind) << 6)) /* kind is parenthesised so expression arguments shift as a whole */ +#define reg_setrid(reg, rid) (((reg)&0xc0) | reg_rid(rid)) +#define reg_isgpr(reg) (reg_rid(reg) < RID_MAX_GPR) +#define reg_isfp(reg) (reg_rid(reg) >= RID_MIN_FPR) +#define reg_isvec(reg) (reg_rid(reg) >= RID_MIN_FPR && reg_kind(reg) >= REGKIND_VEC_START) + +#define reg_irt(reg) (reg_isgpr(reg) ? rk_irtgpr(reg_kind(reg)) : rk_irtfpr(reg_kind(reg))) +#define rk_irtgpr(kind) ((IRType)regkind_it[(kind)]) +#define rk_irtfpr(kind) ((IRType)regkind_it[(kind)+8]) /* FPR kinds use entries 8..15 of the table */ +#define rk_irt(rid, kind) ((rid) < RID_MAX_GPR ? rk_irtgpr(kind) : rk_irtfpr(kind)) +#define rk_isvec(kind) ((kind) >= REGKIND_VEC_START) + +#define rk_ctypegpr(kind) (regkind_ct[(kind)]) +#define rk_ctypefpr(kind) (regkind_ct[(kind)+8]) /* FPR kinds use entries 8..15 of the table */ +#define rk_ctype(rid, kind) ((rid) < RID_MAX_GPR ? 
rk_ctypegpr(kind) : rk_ctypefpr(kind)) + +LJ_FUNC void lj_intrinsic_init(lua_State *L); +LJ_FUNC GCcdata *lj_intrinsic_createffi(CTState *cts, CType *func); +LJ_FUNC int lj_intrinsic_fromcdef(lua_State *L, CTypeID fid, GCstr *opcode, uint32_t imm); +LJ_FUNC int lj_intrinsic_call(CTState *cts, CType *ct); +int lj_intrinsic_getreg(CTState *cts, GCstr *name); /* NOTE(review): declared without LJ_FUNC unlike the prototypes above; confirm this is intended */ + +#define LJ_INTRINS_MAXID 0x1fff /* also the mask that extracts the intrinsic table index in lj_intrinsic_get */ + +static LJ_AINLINE CIntrinsic *lj_intrinsic_get(CTState *cts, CTSize id) /* Returns the entry of cts->intr.tab selected by the masked id. */ +{ + lua_assert((id & LJ_INTRINS_MAXID) < cts->intr.top); + return cts->intr.tab + (id & LJ_INTRINS_MAXID); +} +#endif + diff --git a/src/lj_jit.h b/src/lj_jit.h index ec431a61..f46775c0 100644 --- a/src/lj_jit.h +++ b/src/lj_jit.h @@ -484,6 +484,9 @@ typedef struct jit_State { BCIns patchins; /* Instruction for pending re-patch. */ MCodeArea mcarea; /* JIT mcode area */ +#if LJ_HASINTRINSICS + MCodeArea mcarea_intrins; /* Intrinsic mcode area used for interpreter wrappers */ +#endif MCodeArea *curmcarea; /* Current mcode area by default is mcarea */ TValue errinfo; /* Additional info element for trace errors. */ diff --git a/src/lj_target_x86.h b/src/lj_target_x86.h index 356f7924..75fce4cd 100644 --- a/src/lj_target_x86.h +++ b/src/lj_target_x86.h @@ -60,6 +60,20 @@ enum { RID_MAX_FPR = RID_MAX, RID_NUM_GPR = RID_MAX_GPR - RID_MIN_GPR, RID_NUM_FPR = RID_MAX_FPR - RID_MIN_FPR, + +#if LJ_64 +#if LJ_ABI_WIN + RID_CONTEXT = RID_ECX, + RID_OUTCONTEXT = RID_EDX, +#else + RID_CONTEXT = RID_EDI, + RID_OUTCONTEXT = RID_ESI, +#endif +#else + /* Fast call arguments */ + RID_CONTEXT = RID_ECX, + RID_OUTCONTEXT = RID_EDX, +#endif }; /* -- Register sets ------------------------------------------------------- */ @@ -181,6 +195,14 @@ typedef struct { uint8_t scale; /* Index scale (XM_SCALE1 .. XM_SCALE8). 
*/ } x86ModRM; +typedef struct IntrinWrapState { + struct CIntrinsic *intrins; + RegSet mod; + void* target; + MSize targetsz; + void* wrapper; +}IntrinWrapState; + /* -- Opcodes ------------------------------------------------------------- */ /* Macros to construct variable-length x86 opcodes. -(len+1) is in LSB. */ @@ -210,6 +232,7 @@ typedef enum { XI_JMP = 0xe9, XI_JMPs = 0xeb, XI_PUSH = 0x50, /* Really 50+r. */ + XI_POP = 0x58, /* Really 58+r. */ XI_JCCs = 0x70, /* Really 7x. */ XI_JCCn = 0x80, /* Really 0f8x. */ XI_LEA = 0x8d, @@ -222,6 +245,7 @@ typedef enum { XI_TESTb = 0x84, XI_TEST = 0x85, XI_INT3 = 0xcc, + XI_RET = 0xC3, XI_MOVmi = 0xc7, XI_GROUP5 = 0xff, @@ -287,6 +311,9 @@ typedef enum { XO_MOVSSto = XO_f30f(11), XO_MOVLPD = XO_660f(12), XO_MOVAPS = XO_0f(28), + XO_MOVAPSto = XO_0f(29), + XO_MOVUPS = XO_0f(10), + XO_MOVUPSto = XO_0f(11), XO_XORPS = XO_0f(57), XO_ANDPS = XO_0f(54), XO_ADDSD = XO_f20f(58), diff --git a/src/lj_trace.c b/src/lj_trace.c index 525eacdd..80ab6903 100644 --- a/src/lj_trace.c +++ b/src/lj_trace.c @@ -357,6 +357,9 @@ void lj_trace_freestate(global_State *g) } #endif lj_mcode_free(J, &J->mcarea); +#if LJ_HASINTRINSICS + lj_mcode_free(J, &J->mcarea_intrins); +#endif lj_mem_freevec(g, J->snapmapbuf, J->sizesnapmap, SnapEntry); lj_mem_freevec(g, J->snapbuf, J->sizesnap, SnapShot); lj_mem_freevec(g, J->irbuf + J->irbotlim, J->irtoplim - J->irbotlim, IRIns); diff --git a/tests/debug.sh b/tests/debug.sh new file mode 100644 index 00000000..30a60f93 --- /dev/null +++ b/tests/debug.sh @@ -0,0 +1,7 @@ +#!/bin/bash +cd "${0%/*}" +cd .. 
+make CCDEBUG="-g" CCOPT=" -fomit-frame-pointer" +cd tests +export LUA_PATH="$PWD/?.lua;$PWD/../src/?.lua;$LUA_PATH" +gdb "${@:1}" \ No newline at end of file diff --git a/tests/intrinsic_spec.lua b/tests/intrinsic_spec.lua new file mode 100644 index 00000000..29b943fb --- /dev/null +++ b/tests/intrinsic_spec.lua @@ -0,0 +1,350 @@ +local ffi = require("ffi") +local jit = require("jit") + +ffi.cdef[[ +typedef float float4 __attribute__((__vector_size__(16))); +typedef float float8 __attribute__((__vector_size__(32))); +typedef int int4 __attribute__((__vector_size__(16))); +typedef uint8_t byte16 __attribute__((__vector_size__(16))); +]] + +local float4 = ffi.new("float[4]") +local float4_2 = ffi.new("float[4]", {2, 2, 2, 2}) +local float8 = ffi.new("float[8]", 0) +local byte16 = ffi.new("uint8_t[16]", 1, 0xff, 0) +local int4 = ffi.new("int32_t[5]", 0) +local float4ptr = float4+0 + +local union64 = ffi.new([[ +union __attribute__((packed, aligned(4))){ + int64_t i64; + struct{ + int32_t low; + int32_t high; + }; +}]]) + + +describe("intrinsic tests", function() + +context("nop inout", function() + + it("fpr", function() + assert_cdef([[void fpr_nop1(double xmm0) __mcode("90") __reglist(out, double xmm0)]], "fpr_nop1") + local fpr1 = ffi.C.fpr_nop1 + + assert_error(function() fpr1() end) + assert_error(function() fpr1(nil) end) + assert_error(function() fpr1(1, 2) end) + + assert_jit(123.075, function(num) return (fpr1(num)) end, 123.075) + assert_noexit(-123567.075, function(num) return (fpr1(num)) end, -123567.075) + end) + + it("gpr", function() + assert_cdef([[void gpr_nop1(int32_t eax) __mcode("90") __reglist(out, int32_t eax)]], "gpr_nop1") + + local function testgpr1(num) + return (ffi.C.gpr_nop1(num)) + end + + assert_jit(1235678, testgpr1, 1235678) + assert_noexit(-1, testgpr1, -1) + + assert_cdef([[void gpr_scatch(int32_t eax, int32_t ecx, int32_t edx) __mcode("90_E") + __reglist(out, int32_t eax, int32_t ecx, int32_t edx)]], "gpr_scatch") + + + local 
function testgpr_scratch(i, r1, r2, r3) + local ro1, ro2, ro3 = ffi.C.gpr_scatch(r1, r2, r3) + return ro1+i, ro2+i, ro3+i + end + + local function checker(i, ro1, ro2, ro3) + assert(ro1 == 0+i) + assert(ro2 == 1+i) + assert(ro3 == 30000+i) + end + + assert_jitchecker(checker, testgpr_scratch, 0, 1, 30000) + end) + +if ffi.arch == "x64" then + it("gpr64", function() + assert_cdef([[void gpr64_1(int64_t rdx) __mcode("90") __reglist(out, int64_t rdx)]], "gpr64_1") + + local function testgpr1(num) + return (ffi.C.gpr64_1(num)) + end + + assert_jit(1235678ull, testgpr1, 1235678) + assert_noexit(-1LL, testgpr1, -1) + end) + + it("rex fpr", function() + assert_cdef([[void fpr_reg(double xmm9, double xmm0) __mcode("90") __reglist(out, double xmm0, double xmm9)]], "fpr_reg") + local fpr = ffi.C.fpr_reg + + local function testrex(n1, n2) + local o1, o2 = fpr(n1, n2) + return o1+o2 + end + + assert_jit(444.575, testrex, 123.075, 321.5) + end) +end + + it("fpr_vec", function() + assert_cdef([[void fpr_vec(void* xmm7v) __mcode("90") __reglist(out, float4 xmm7v)]], "fpr_vec") + + local v1 = ffi.new("float[4]", 1, 2, 3, 4) + local xmmout = ffi.C.fpr_vec(v1) + assert_v4eq(xmmout, 1, 2, 3, 4) + end) + + it("idiv", function() + assert_cdef([[void idiv(int32_t eax, int32_t ecx) __mcode("99F7F9") __reglist(out, int32_t eax, int32_t edx)]], "idiv") + + local function checker(i, result, remainder) + local rem = i%3 + + if rem ~= remainder then + return rem, remainder + end + + local expected = (i-rem)/3 + + if expected ~= result then + return expected, result + end + end + + local function test_idiv(value, divisor) + local result, remainder = ffi.C.idiv(value, divisor) + return result, remainder + end + + assert_jitchecker(checker, test_idiv, 3) + + --test with jited with a constant arg + local function test_idivK(value) + local result, remainder = ffi.C.idiv(value, 3) + return result, remainder + end + + assert_jitchecker(checker, test_idivK, 3) + end) +end) + + +context("__mcode", 
function() + + it("incomplete mcode def", function() + assert_cdeferr([[int test1() __mcode]]) + assert_cdeferr([[int test2() __mcode(]]) + assert_cdeferr([[int test3() __mcode()]]) + assert_cdeferr([[int test3() __mcode(,)]]) + assert_cdeferr([[int test4() __mcode("ff"]]) + assert_cdeferr([[int test5() __mcode("ff",,)]]) + assert_cdeferr([[int test6() __mcode("ff" 1)]]) + assert_cdeferr([[int test7() __mcode("ff", )]]) + assert_cdeferr([[int test8() __mcode("ff", 1]]) + assert_cdeferr([[int test9() __mcode("ff", 1, 1]]) + assert_cdeferr([[int test10() __mcode("ff", 1, 1, ]]) + + assert_cdeferr([[__mcode("90")]]) + assert_cdeferr([[int __mcode("90")]]) + end) + + it("bad mcoddef", function() + assert_cdeferr([[void test1(float a) __mcode(0);]]) + assert_cdeferr([[void test2(float a) __mcode("");]]) + assert_cdeferr([[void test3(float a) __mcode("0");]]) + assert_cdeferr([[void test4(float a) __mcode("rff");]]) + assert_cdeferr([[struct c{float a __mcode("90");};]]) + --Max 2 literals after the opcode string + assert_cdeferr([[int test11() __mcode("ff", 1, 1, 2)]]) + + assert_cdeferr([[struct b{float a; __mcode("90");};]]) + end) + + it("invalid registers", function() + assert_cdef([[void validreg_gpr(int eax) __mcode("90");]], "validreg_gpr") + + assert_cdeferr([[void badreg_1(int e) __mcode("90");]], "invalid") + assert_cdeferr([[void badreg_1(int r20d) __mcode("90");]], "invalid") + assert_cdeferr([[void badreg_gpr1() __mcode("90") __reglist(out, int e);]], "invalid") + assert_cdeferr([[void badreg_gpr2() __mcode("90") __reglist(mod, e);]], "invalid") + + assert_cdef([[void validreg_fpr(float xmm0) __mcode("90");]], "validreg_fpr") + + assert_cdeferr([[void badreg_fpr1(float x) __mcode("90");]], "invalid") + assert_cdeferr([[void badreg_fpr1(float xm) __mcode("90");]], "invalid") + assert_cdeferr([[void badreg_fpr1(float xm0) __mcode("90");]], "invalid") + assert_cdeferr([[void badreg_fpr1(float xmmm0) __mcode("90");]], "invalid") + assert_cdeferr([[void 
badreg_fpr2(float xmm0vf) __mcode("90");]], "invalid") + --xmm register number too large + assert_cdeferr([[void badreg_fpr1(float xmm20) __mcode("90");]], "invalid") + end) + + it("multidef rollback", function() + + --check ctype rollback after parsing a valid intrinsic the line before + assert_cdeferr([[ + void multi1() __mcode("90"); + void multi2() __mcode("0"); + ]]) + + assert_error(function() ffi.C.multi1() end) + assert_error(function() ffi.C.multi2() end) + + assert_not_error(function() ffi.cdef[[ + void multi1(int32_t eax) __mcode("90") __reglist(out, int32_t eax); + ]] end) + + assert_equal(ffi.C.multi1(1.1), 1) + end) + + it("bad ffi types mcode", function() + assert_cdeferr([[void testffi1(float a2, ...) __mcode("90");]]) + assert_cdeferr([[void testffi2(complex a2) __mcode("90");]]) + + --NYI non 16/32 byte vectors + assert_cdeferr([[ + typedef float float2 __attribute__((__vector_size__(8))); + void testffi2(float2 a2) __mcode("90") + ]]) + end) + + it("bad args", function() + assert_cdef([[void idiv2(int32_t eax, int32_t ecx) __mcode("99F7F9") __reglist(out, int32_t eax, int32_t edx)]], "idiv2") + + local idiv = ffi.C.idiv2 + + assert_equal(idiv(6, 2), 3) + --too few arguments + assert_error(function() idiv() end) + assert_error(function() idiv(nil) end) + assert_error(function() idiv(1) end) + assert_error(function() idiv(1, nil) end) + + --too many arguments + assert_error(function() idiv(1, 2, nil) end) + assert_error(function() idiv(1, 2, 3) end) + assert_error(function() idiv(1, 2, 3, 4) end) + end) + + it("cpuid_brand", function() + assert_cdef([[void cpuid(int32_t eax, int32_t ecx) __mcode("0FA2") __reglist(out, int32_t eax, int32_t ebx, int32_t ecx, int32_t edx);]], "cpuid") + + local cpuid = ffi.C.cpuid + + local function getcpuidstr(eax) + int4[0] = 0; int4[1] = 0; int4[2] = 0; int4[3] = 0 + int4[0], int4[1], int4[2], int4[3] = cpuid(eax, 0) + return (ffi.string(ffi.cast("char*", int4+0))) + end + + local brand = 
getcpuidstr(-2147483646)..getcpuidstr(-2147483645)..getcpuidstr(-2147483644) + print("Processor brand: "..brand) + + local function testcpuid_brand() + local s = "" + + int4[0] = 0 + int4[1] = 0 + int4[2] = 0 + int4[3] = 0 + + int4[0], int4[1], int4[2], int4[3] = cpuid(-2147483646, 0) + s = s..ffi.string(ffi.cast("char*", int4+0)) + + int4[0], int4[1], int4[2], int4[3] = cpuid(-2147483645, 0) + s = s..ffi.string(ffi.cast("char*", int4+0)) + + int4[0], int4[1], int4[2], int4[3] = cpuid(-2147483644, 0) + s = s..ffi.string(ffi.cast("char*", int4+0)) + + return s + end + + assert_jit(brand, testcpuid_brand) + end) +end) + +context("__reglist", function() + + it("incomplete reglist", function() + assert_cdeferr([[int test1() __mcode("90") __reglist]]) + assert_cdeferr([[int test2() __mcode("90") __reglist(]]) + assert_cdeferr([[int test3() __mcode("90") __reglist();]]) + assert_cdeferr([[int test4() __mcode("90") __reglist(,);]]) + assert_cdeferr([[int test5() __mcode("90") __reglist(in, eax);]]) + assert_cdeferr([[int test6() __mcode("90") __reglist(out, ]]) + assert_cdeferr([[int test6() __mcode("90") __reglist(mod, ]]) + + assert_cdeferr([[int test7() __mcode("90") __reglist(mod, eax, ]]) + assert_cdeferr([[int test8() __mcode("90") __reglist("out, ]]) + assert_cdeferr([[int test9() __mcode("90") __reglist(o]]) + assert_cdeferr([[int test10() __mcode("90") __reglist(ou]]) + assert_cdeferr([[int invalid_reglist4() __mcode("90") __reglist(out, int)]]) + assert_cdeferr([[int invalid_reglist4() __mcode("90") __reglist(out, int eax,)]]) + end) + + it("invalid reglist", function() + assert_cdeferr([[int invalid_reglist1() __mcode("90") __reglist(inn, int eax)]]) + assert_cdeferr([[int invalid_reglist2() __mcode("90") __reglist(o, int eax)]]) + assert_cdeferr([[int invalid_reglist3() __mcode("90") __reglist(oout, int eax)]]) + assert_cdeferr([[int invalid_reglist4() __mcode("90") __reglist(out, int reax)]]) + + --exceeded max register list size + assert_cdeferr([[int 
invalid_reglist5() __mcode("90") __reglist(out, int eax, int ebx, + int ecx, int edx, int esi, int edi, float xmm0, float xmm1, float xmm2)]]) + end) + + it("stack pointer blacklist", function() + + assert_cdeferr([[void blacklist_in(int esp) __mcode("90")]], "blacklist") + assert_cdeferr([[void blacklist_out(int eax) __mcode("90") __reglist(out, int esp)]], "blacklist") + --FIXME + --assert_cdeferr([[void blacklist_mod(int eax) __mcode("90") __reglist(mod, esp)]], "blacklist") + + if ffi.arch == "x64" then + assert_cdeferr([[void blacklist_64(int rsp) __mcode("90")]], "blacklist") + end + end) + + it("duplicate regs", function() + assert_cdeferr([[void duplicate_in(int eax, int eax) __mcode("90")]], "duplicate") + assert_cdeferr([[void duplicate_inxmm(float4 xmm0, float4 xmm0) __mcode("90")]], "duplicate") + assert_cdeferr([[void duplicate_out(int eax) __mcode("90") __reglist(out, int eax, int eax)]], "duplicate") + --FIXME assert_cdeferr([[void duplicate_mod(int eax) __mcode("90_E") __reglist(mod, eax, eax)]], "duplicate") + end) + + it("rdtsc", function() + assert_cdef([[void rdtsc() __mcode("0f31") __reglist(out, int32_t eax, int32_t edx);]], "rdtsc") + + local rdtsc = ffi.C.rdtsc + + local function getticks() + union64.low, union64.high = rdtsc() + return union64.i64 + end + + local prev = 0ll + + local function checker(i, result) + --print(tonumber(result-prev)) + assert(result > prev) + + prev = result + end + + assert_jitchecker(checker, getticks) + end) +end) + + +end) + + diff --git a/tests/jit_tester.lua b/tests/jit_tester.lua new file mode 100644 index 00000000..8f4e884e --- /dev/null +++ b/tests/jit_tester.lua @@ -0,0 +1,399 @@ +local jit = require("jit") +local jutil = require("jit.util") +local vmdef = require("jit.vmdef") +local tracker = require("tracetracker") +local funcinfo, funcbc, traceinfo = jutil.funcinfo, jutil.funcbc, jutil.traceinfo +local traceinfo, traceir, tracek = jutil.traceinfo, jutil.traceir, jutil.tracek +local band = bit.band 
+local unpack = unpack +--set a high count for now to work around hot counter backoff +local testloopcount = 80 +local sub = string.sub + +local function fmtfunc(func, pc) + local fi = funcinfo(func, pc) + if fi.loc then + return fi.loc + elseif fi.ffid then + return vmdef.ffnames[fi.ffid] + elseif fi.addr then + return string.format("C:%x", fi.addr) + else + return "(?)" + end +end + +local bcnames = {} + + +for i=1,#vmdef.bcnames/6 do + bcnames[i] = string.sub(vmdef.bcnames, i+1, i+6) +end + +local irlookup = {} + +for i=0,(#vmdef.irnames)/6 do + local ir = sub(vmdef.irnames, (i*6)+1, (i+1) * 6) + ir = ir:match("^%s*(.-)%s*$") + irlookup[ir] = i +end + +local function getiroptype(tr, ref) + + local m, ot = traceir(tr, ref) + + return shr(ot, 8) +end + +function checkir(tr, ir, exop1, exop2, start) + + local info = traceinfo(tr) + assert(info) + + local findop = irlookup[ir] + assert(findop, "No IR opcode named "..tostring(ir)) + + local op1kind = (exop1 and irlookup[exop1]) or -1 + assert(op1kind == -1 or exop1 == nil, "op1 no IR opcode named "..tostring(exop1)) + + local op2kind = (exop2 and irlookup[exop2]) or -1 + assert(op2kind == -1 or exop2 == nil, "op2 no IR opcode named "..tostring(exop2)) + + start = start or 1 + local err = false + + for ins=start,info.nins do + local m, ot, op1, op2, ridsp = traceir(tr, ins) + local op, t = shr(ot, 8), band(ot, 31) + + if op == findop then + local match = true + + if op1kind ~= -1 and op1 ~= op1kind then + match = false + end + + + return ins + end + end + + trerror("Trace contained no %s IR opcodes", ir) +end + + +local expectedlnk = "return" + +local function trerror(s, a1, ...) 
+ + tracker.print_savedevevents() + + if(a1) then + error(string.format(s, a1, ...), 4) + else + error(s, 4) + end + +end + +local function asserteq(result, expected, info) + if result ~= expected then + error(string.format("expected %q but got %q - %s", tostring(expected), tostring(result), info or "", 3)) + end +end + +local function checktrace(func, tr, mode) + + local traces = tracker.traces() + + if tr == nil then + + if #traces == 0 then + trerror("no traces were started for test") + else + --TODO: Filter out trace that are from checker code + tr = traces[1] + end + end + + if tr.abort then + trerror("trace aborted with error %s at %s", abort, fmtfunc(tr.stopfunc, tr.stoppc)) + end + + local info = traceinfo(tr.traceno) + + if info.linktype == "stitch" and expectedlnk ~= "stitch" then + trerror("trace did not cover full function stitched at %s", fmtfunc(tr.stopfunc, tr.stoppc)) + end + + if tr.startfunc ~= func then + trerror("trace did not start in tested function. started in %s", fmtfunc(tr.startfunc, tr.startpc)) + end + + if tr.stopfunc ~= func then + trerror("trace did not stop in tested function. stoped in %s", fmtfunc(tr.stopfunc, tr.stoppc)) + end + + if info.linktype ~= expectedlnk then + trerror("expect trace link '%s but got %s", expectedlnk, info.linktype) + end + + if mode == "root" then + if tracker.hasexits() then + trerror("unexpect traces exits ") + end + + if #traces > 1 then + trerror("unexpect extra traces were started for test ") + end + end +end + +local started_tracker = false + +local function begintest(func) + if not started_tracker then + tracker.start() + started_tracker = true + end + + jit.flush() + jit.on(func, true) --clear any interpreter only function/loop headers that may have been caused by other tests + tracker.clear() +end + +local function trerror2(s, a1, ...) + + tracker.print_savedevevents() + + if(a1) then + error(string.format(s, a1, ...), 3) + else + error(s, 3) + end +end + +function testsingle(expected, func, ...) 
+ + begintest(func) + + local expectedval = expected + + for i=1, testloopcount do + + local result + + if type(expected) == "function" then + expectedval, result = expected(i, func(...)) + else + result = func(...) + end + + if (result ~= expectedval) then + local jitted, anyjited = tracker.isjited(func) + tracker.print_savedevevents() + trerror2("expected %q but got %q - %s", tostring(expected), tostring(result), (jitted and "JITed") or "Interpreted") + end + + end + + checktrace(func, tr, "root") + + return true +end + +function testwithchecker(checker, func, ...) + begintest(func) + + jit.off(checker) + + for i=1, testloopcount do + + local expected, result = checker(i, func(i, ...)) + + if (result ~= expected) then + local jitted, anyjited = tracker.isjited(func) + tracker.print_savedevevents() + trerror2("expected %q but got %q - %s", tostring(expected), tostring(result), (jitted and "JITed") or "Interpreted") + end + + end + + checktrace(func, nil, "root") + + return true +end + + +local state = { + WaitFirstTrace = 1, + CheckNoExits = 2, + RunNextConfig = 3, + CheckCompiledSideTrace = 4, +} + +--FIXME: the side traces that happen for config 2 will always abort because they trace out into this function which has jit turned off +local function testexits(func, config1, config2) + + begintest(func) + + local jitted = false + local trcount = 0 + local config, expected, shoulderror = config1, config1.expected, config1.shoulderror + local state = 1 + local sidestart = 0 + + for i=1, testloopcount do + local status, result + + if not shoulderror then + result = func(unpack(config.args)) + else + status, result = pcall(func, unpack(config.args)) + + if(status) then + tracker.print_savedevevents() + trerror2("expected call to trigger error but didn't "..tostring(i)) + end + end + + if state == 2 then + if tracker.hasexits() then + trerror2("trace exited on first run after being compiled "..expected) + end + state = 3 + end + + local newtraces = tracker.traceattemps() 
~= trcount + + if newtraces then + trcount = tracker.traceattemps() + jitted, anyjited = tracker.isjited(func) + + if state == 1 then + --let the trace be executed once before we switch to the next arguments + state = 2 + elseif state == 4 and not tracker.traces()[trcount].abort then + state = 5 + sidestart = tracker.exitcount() + print("side trace compiled ".. tostring(expected)) + end + end + + if not shoulderror and result ~= expected then + tracker.print_savedevevents() + error(string.format("expected %q but got %q - %s", tostring(expected), tostring(result), (jitted and "JITed") or "Interpreted"), 2) + end + + if state == 3 then + config = config2 + expected = config2.expected + shoulderror = config2.shoulderror + state = 4 + end + + end + + local traces = tracker.traces() + + if #traces == 0 then + trerror2("no traces were started for test "..expected) + end + + local tr = traces[1] + + checktrace(func, tr) + + if not tracker.hasexits() then + trerror2("Expect trace to exit to interpreter") + end + + if sidestart ~= 0 and tracker.exitcount() > sidestart then + trerror2("Unexpected exits from side trace") + end + + assert(state >= 4) + + return true +end + +local function texiterror(msg) + tracker.printexits() + error(msg, 4) +end + +local function testexit(expected, func, ...) + + tracker.clearexits() + local result = func(...) + + if not tracker.hasexits() then + texiterror("Expected trace to exit but didn't") + end + + asserteq(result, expected) + + return true +end + +local function testnoexit(expected, func, ...) + + tracker.clearexits() + local result = func(...) + + if tracker.hasexits() then + texiterror("Unexpected trace exits") + end + + asserteq(result, expected) + + return true +end + +local function testexiterr(func, ...) + + tracker.clearexits() + local status, result = pcall(func, ...) 
+ + if not tracker.hasexits() then + texiterror("Expected trace to exit but didn't") + end + + if(status) then + texiterror("Expected call to trigger error but didn't ") + end + + return true +end + +jit.off(true, true) + +local function setasserteq(func) + jit.off(func) + + --force interpreter only func header bc + for i=1, 30 do + func(1, 1) + end + asserteq = func +end + +require("jit.opt").start("hotloop=2") +--force the loop and function header in testjit to abort and be patched +local dummyfunc = function() return "" end +for i=1,30 do + pcall(testsingle, "", dummyfunc, "") +end + +require("jit.opt").start("hotloop=6") + +return { + testsingle = testsingle, + testwithchecker = testwithchecker, + testexit = testexit, + testnoexit = testnoexit, + testexiterr = testexiterr, + testexits = testexits, + testloopcount = testloopcount, + setasserteq = setasserteq, +} \ No newline at end of file diff --git a/tests/runtests.bat b/tests/runtests.bat new file mode 100644 index 00000000..f1aafe9e --- /dev/null +++ b/tests/runtests.bat @@ -0,0 +1,15 @@ +ECHO off + +cd "%~dp0" + +SET src=..\src +CALL :normalise "%src%" + +set LUA_PATH=%~dp0?.lua;%src%/?.lua + +..\src\luajit.exe runtests.lua +pause + +:normalise +SET "src=%~f1" +GOTO :EOF \ No newline at end of file diff --git a/tests/runtests.lua b/tests/runtests.lua new file mode 100644 index 00000000..6b663a34 --- /dev/null +++ b/tests/runtests.lua @@ -0,0 +1,124 @@ +local tester = require("jit_tester") +local testjit = tester.testsingle +local telescope = require("telescope") +local ffi = require("ffi") +local C = ffi.C + +local function check(expect, func, ...) + local result = func(...) + assert(result == expect, tostring(result)) + return true +end + +telescope.make_assertion("jit", "", check) +telescope.make_assertion("exit", "", check) +telescope.make_assertion("noexit", "", check) + +telescope.make_assertion("jitchecker", "", function(checker, func, ...) 
+ + local expected, value = checker(1, func(1, ...)) + assert(expected == value) + return true +end) + +telescope.make_assertion("cdef", "", function(cdef, name) + assert(not name or type(name) == "string") + ffi.cdef(cdef) + if name then assert(C[name]) end + return true +end) + +telescope.make_assertion("cdeferr", "expected cdef '%s' to error", function(cdef, msg) + local success, ret = pcall(ffi.cdef, cdef) + if success then return false end + if msg then + local found = string.find(ret, msg) + + if not found then + error(string.format('cdef error message did not containt string: \n "%s" \nerror was\n "%s"', msg, ret)) + end + end + return true +end) + +telescope.make_assertion("v4eq", "", function(v, x, y, z, w) + + if v[0] ~= x then + error(string.format("expected v[0] to equal %s was %s", tostring(x), tostring(v[0]))) + elseif v[1] ~= y then + error(string.format("expected v[1] to equal %s was %s", tostring(y), tostring(v[1]))) + elseif v[2] ~= z then + error(string.format("expected v[2] to equal %s was %s", tostring(z), tostring(v[2]))) + elseif v[3] ~= w then + error(string.format("expected v[3] to equal %s was %s", tostring(w), tostring(v[3]))) + end + + return true +end) + +filter = filter or "" + +local callbacks = {} + +local function printfail() + print(" Failed!") +end + +callbacks.err = printfail +callbacks.fail = printfail + +function callbacks.before(t) + print("running", t.name) +end + +local contexts = {} +local files = {"intrinsic_spec.lua"} + +for _, file in ipairs(files) do + telescope.load_contexts(file, contexts) +end + +local buffer = {} +local testfilter + +if filter then + + if(type(filter) == "table") then + testfilter = function(t) + for _,patten in ipairs(filter) do + if t.name:match(patten) then + return true + end + end + + return false + end + elseif(type(filter) == "number") then + local count = 0 + local reverse = filter < 0 + testfilter = function(t) + count = count+1 + if ((not reverse and count > filter) or (reverse and 
(count+filter) < 0)) then + return false + end + + return true + end + elseif(filter ~= "") then + testfilter = function(t) return t.name:match(filter) end + end +end + +local results = telescope.run(contexts, callbacks, testfilter) +local summary, data = telescope.summary_report(contexts, results) +table.insert(buffer, summary) +local report = telescope.error_report(contexts, results) + +if report then + table.insert(buffer, "") + table.insert(buffer, report) +end + +if #buffer > 0 then + print(table.concat(buffer, "\n")) +end \ No newline at end of file diff --git a/tests/runtests.sh b/tests/runtests.sh new file mode 100644 index 00000000..cf2243bb --- /dev/null +++ b/tests/runtests.sh @@ -0,0 +1,4 @@ +#!/bin/bash +cd "${0%/*}" +export LUA_PATH="$PWD/?.lua;$PWD/../src/?.lua;$LUA_PATH" +../src/luajit ./test.lua \ No newline at end of file diff --git a/tests/telescope.lua b/tests/telescope.lua new file mode 100644 index 00000000..2f27e6ad --- /dev/null +++ b/tests/telescope.lua @@ -0,0 +1,594 @@ +--- Telescope is a test library for Lua that allows for flexible, declarative +-- tests. The documentation produced here is intended largely for developers +-- working on Telescope. For information on using Telescope, please visit the +-- project homepage at: http://github.com/norman/telescope#readme. +-- @release 0.6 +-- @class module +-- @module 'telescope' +local _M = {} +local getfenv = _G.getfenv +local setfenv = _G.setfenv + + +local _VERSION = "0.6.0" + +--- The status codes that can be returned by an invoked test. These should not be overidden. +-- @name status_codes +-- @class table +-- @field err - This is returned when an invoked test results in an error +-- rather than a passed or failed assertion. +-- @field fail - This is returned when an invoked test contains one or more failing assertions. +-- @field pass - This is returned when all of a test's assertions pass. +-- @field pending - This is returned when a test does not have a corresponding function. 
+-- @field unassertive - This is returned when an invoked test does not produce +-- errors, but does not contain any assertions. +local status_codes = { + err = 2, + fail = 4, + pass = 8, + pending = 16, + unassertive = 32 +} + +--- Labels used to show the various status_codes as a single character. +-- These can be overidden if you wish. +-- @name status_labels +-- @class table +-- @see status_codes +-- @field status_codes.err 'E' +-- @field status_codes.fail 'F' +-- @field status_codes.pass 'P' +-- @field status_codes.pending '?' +-- @field status_codes.unassertive 'U' + +local status_labels = { + [status_codes.err] = 'E', + [status_codes.fail] = 'F', + [status_codes.pass] = 'P', + [status_codes.pending] = '?', + [status_codes.unassertive] = 'U' +} + +--- The default names for context blocks. It defaults to "context", "spec" and +-- "describe." +-- @name context_aliases +-- @class table +local context_aliases = {"context", "describe", "spec"} +--- The default names for test blocks. It defaults to "test," "it", "expect", +-- "they" and "should." +-- @name test_aliases +-- @class table +local test_aliases = {"test", "it", "expect", "should", "they"} + +--- The default names for "before" blocks. It defaults to "before" and "setup." +-- The function in the before block will be run before each sibling test function +-- or context. +-- @name before_aliases +-- @class table +local before_aliases = {"before", "setup"} + +--- The default names for "after" blocks. It defaults to "after" and "teardown." +-- The function in the after block will be run after each sibling test function +-- or context. +-- @name after_aliases +-- @class table +local after_aliases = {"after", "teardown"} + +-- Prefix to place before all assertion messages. Used by make_assertion(). +local assertion_message_prefix = "Assert failed: expected " + +--- The default assertions. +-- These are the assertions built into telescope. 
You can override them or +-- create your own custom assertions using make_assertion. +-- +-- @see make_assertion +-- @name assertions +-- @class table +local assertions = {} + +--- Create a custom assertion. +-- This creates an assertion along with a corresponding negative assertion. It +-- is used internally by telescope to create the default assertions. +-- @param name The base name of the assertion. +--

+-- The name will be used as the basis of the positive and negative assertions; +-- i.e., the name "equal" would be used to create the assertions +-- "assert_equal" and "assert_not_equal". +--

+-- @param message The base message that will be shown. +--

+-- The assertion message is what is shown when the assertion fails. It will be +-- prefixed with the string in telescope.assertion_message_prefix. +-- The variables passed to telescope.make_assertion are interpolated +-- in the message string using string.format. When creating the +-- inverse assertion, the message is reused, with " to be " replaced +-- by " not to be ". Hence a recommended format is something like: +-- "%s to be similar to %s". +--

+-- @param func The assertion function itself. +--

+-- The assertion function can have any number of arguments. +--

+-- @usage make_assertion("equal", "%s to be equal to %s", function(a, b) +-- return a == b end) +-- @function make_assertion +local function make_assertion(name, message, func) + local num_vars = 0 + -- if the last vararg ends up nil, we'll need to pad the table with nils so + -- that string.format gets the number of args it expects + local format_message + if type(message) == "function" then + format_message = message + else + for _, _ in message:gmatch("%%s") do num_vars = num_vars + 1 end + format_message = function(message, ...) + local a = {} + local args = {...} + local nargs = select('#', ...) + if nargs > num_vars then + local userErrorMessage = args[num_vars+1] + if type(userErrorMessage) == "string" then + return(assertion_message_prefix .. userErrorMessage) + else + error(string.format('assert_%s expected %d arguments but got %d', name, num_vars, #args)) + end + end + for i = 1, nargs do a[i] = tostring(args[i]) end + for i = nargs+1, num_vars do a[i] = 'nil' end + return (assertion_message_prefix .. message):format(unpack(a)) + end + end + + assertions["assert_" .. name] = function(...) + if assertion_callback then assertion_callback(...) end + if not func(...) then + error({format_message(message, ...), debug.traceback()}) + end + end +end + +--- (local) Return a table with table t's values as keys and keys as values. +-- @param t The table. +local function invert_table(t) + local t2 = {} + for k, v in pairs(t) do t2[v] = k end + return t2 +end + +-- (local) Truncate a string "s" to length "len", optionally followed by the +-- string given in "after" if truncated; for example, truncate_string("hello +-- world", 3, "...") +-- @param s The string to truncate. +-- @param len The desired length. +-- @param after A string to append to s, if it is truncated. +local function truncate_string(s, len, after) + if #s <= len then + return s + else + local s = s:sub(1, len):gsub("%s*$", '') + if after then return s .. 
after else return s end + end +end + +--- (local) Filter a table's values by function. This function iterates over a +-- table , returning only the table entries that, when passed into function f, +-- yield a truthy value. +-- @param t The table over which to iterate. +-- @param f The filter function. +local function filter(t, f) + local a, b + return function() + repeat a, b = next(t, a) + if not b then return end + if f(a, b) then return a, b end + until not b + end +end + +--- (local) Finds the value in the contexts table indexed with i, and returns a table +-- of i's ancestor contexts. +-- @param i The index in the contexts table to get ancestors for. +-- @param contexts The table in which to find the ancestors. +local function ancestors(i, contexts) + if i == 0 then return end + local a = {} + local function func(j) + if contexts[j].parent == 0 then return nil end + table.insert(a, contexts[j].parent) + func(contexts[j].parent) + end + func(i) + return a +end + +make_assertion("blank", "'%s' to be blank", function(a) return a == '' or a == nil end) +make_assertion("empty", "'%s' to be an empty table", function(a) return not next(a) end) +make_assertion("equal", "'%s' to be equal to '%s'", function(a, b) return a == b end) +make_assertion("error", "result to be an error", function(f) return not pcall(f) end) +make_assertion("false", "'%s' to be false", function(a) return a == false end) +make_assertion("greater_than", "'%s' to be greater than '%s'", function(a, b) return a > b end) +make_assertion("gte", "'%s' to be greater than or equal to '%s'", function(a, b) return a >= b end) +make_assertion("less_than", "'%s' to be less than '%s'", function(a, b) return a < b end) +make_assertion("lte", "'%s' to be less than or equal to '%s'", function(a, b) return a <= b end) +make_assertion("match", "'%s' to be a match for %s", function(a, b) return (tostring(b)):match(a) end) +make_assertion("nil", "'%s' to be nil", function(a) return a == nil end) 
+make_assertion("true", "'%s' to be true", function(a) return a == true end) +make_assertion("type", "'%s' to be a %s", function(a, b) return type(a) == b end) + +make_assertion("not_blank", "'%s' not to be blank", function(a) return a ~= '' and a ~= nil end) +make_assertion("not_empty", "'%s' not to be an empty table", function(a) return not not next(a) end) +make_assertion("not_equal", "'%s' not to be equal to '%s'", function(a, b) return a ~= b end) +make_assertion("not_error", "result not to be an error", function(f) return not not pcall(f) end) +make_assertion("not_match", "'%s' not to be a match for %s", function(a, b) return not (tostring(b)):match(a) end) +make_assertion("not_nil", "'%s' not to be nil", function(a) return a ~= nil end) +make_assertion("not_type", "'%s' not to be a %s", function(a, b) return type(a) ~= b end) + +--- Build a contexts table from the test file or function given in target. +-- If the optional contexts table argument is provided, then the +-- resulting contexts will be added to it. +--

+-- The resulting contexts table's structure is as follows: +--

+-- +-- { +-- {parent = 0, name = "this is a context", context = true}, +-- {parent = 1, name = "this is a nested context", context = true}, +-- {parent = 2, name = "this is a test", test = function}, +-- {parent = 2, name = "this is another test", test = function}, +-- {parent = 0, name = "this is test outside any context", test = function}, +-- } +-- +-- @param contexts A optional table in which to collect the resulting contexts +-- and function. +-- @function load_contexts +local function load_contexts(target, contexts) + local env = {} + local current_index = 0 + local context_table = contexts or {} + + local function context_block(name, func) + table.insert(context_table, {parent = current_index, name = name, context = true}) + local previous_index = current_index + current_index = #context_table + func() + current_index = previous_index + end + + local function test_block(name, func) + local test_table = {name = name, parent = current_index, test = func or true} + if current_index ~= 0 then + test_table.context_name = context_table[current_index].name + else + test_table.context_name = 'top level' + end + table.insert(context_table, test_table) + end + + local function before_block(func) + context_table[current_index].before = func + end + + local function after_block(func) + context_table[current_index].after = func + end + + for _, v in ipairs(after_aliases) do env[v] = after_block end + for _, v in ipairs(before_aliases) do env[v] = before_block end + for _, v in ipairs(context_aliases) do env[v] = context_block end + for _, v in ipairs(test_aliases) do env[v] = test_block end + + -- Set these functions in the module's meta table to allow accessing + -- telescope's test and context functions without env tricks. This will + -- however add tests to a context table used inside the module, so multiple + -- test files will add tests to the same top-level context, which may or may + -- not be desired. 
+ setmetatable(_M, {__index = env}) + + setmetatable(env, {__index = _G}) + + local func, err = type(target) == 'string' and assert(loadfile(target)) or target + if err then error(err) end + setfenv(func, env)() + return context_table +end + +-- in-place table reverse. +function table.reverse(t) + local len = #t+1 + for i=1, (len-1)/2 do + t[i], t[len-i] = t[len-i], t[i] + end +end + +--- Run all tests. +-- This function will exectute each function in the contexts table. +-- @param contexts The contexts created by load_contexts. +-- @param callbacks A table of callback functions to be invoked before or after +-- various test states. +--

+-- There is a callback for each test status_code, and callbacks to run +-- before or after each test invocation regardless of outcome. +--

+-- +--

+-- Callbacks can be used, for example, to drop into a debugger upon a failed +-- assertion or error, for profiling, or updating a GUI progress meter. +--

+-- @param test_filter A function that selects which tests to run: only tests matching the conditions you specify are executed. +--

+-- For example, the following would allow you to run only tests whose name matches a pattern: +--

+--

+-- +-- function(t) return t.name:match("%s* lexer") end +-- +--

+-- @return A table of result tables. Each result table has the following +-- fields: +-- +-- @see load_contexts +-- @see status_codes +-- @function run +local function run(contexts, callbacks, test_filter) + + local results = {} + local status_names = invert_table(status_codes) + local test_filter = test_filter or function(a) return a end + + -- Setup a new environment suitable for running a new test + local function newEnv() + local env = {} + + -- Make sure globals are accessible in the new environment + setmetatable(env, {__index = _G}) + + -- Setup all the assert functions in the new environment + for k, v in pairs(assertions) do + setfenv(v, env) + env[k] = v + end + + return env + end + + local env = newEnv() + + local function invoke_callback(name, test) + if not callbacks then return end + if type(callbacks[name]) == "table" then + for _, c in ipairs(callbacks[name]) do c(test) end + elseif callbacks[name] then + callbacks[name](test) + end + end + + local function invoke_test(func) + local assertions_invoked = 0 + env.assertion_callback = function() + assertions_invoked = assertions_invoked + 1 + end + setfenv(func, env) + local result, message = xpcall(func, debug.traceback) + if result and assertions_invoked > 0 then + return status_codes.pass, assertions_invoked, nil + elseif result then + return status_codes.unassertive, 0, nil + elseif type(message) == "table" then + return status_codes.fail, assertions_invoked, message + else + return status_codes.err, assertions_invoked, {message, debug.traceback()} + end + end + + for i, v in filter(contexts, function(i, v) return v.test and test_filter(v) end) do + env = newEnv() -- Setup a new environment for this test + + local ancestors = ancestors(i, contexts) + local context_name = 'Top level' + if contexts[i].parent ~= 0 then + context_name = contexts[contexts[i].parent].name + end + local result = { + assertions_invoked = 0, + name = contexts[i].name, + context = context_name, + test = i + } + 
table.sort(ancestors) + -- this "before" is the test callback passed into the runner + invoke_callback("before", result) + + -- run all the "before" blocks/functions + for _, a in ipairs(ancestors) do + if contexts[a].before then + setfenv(contexts[a].before, env) + contexts[a].before() + end + end + + -- check if it's a function because pending tests will just have "true" + if type(v.test) == "function" then + result.status_code, result.assertions_invoked, result.message = invoke_test(v.test) + invoke_callback(status_names[result.status_code], result) + else + result.status_code = status_codes.pending + invoke_callback("pending", result) + end + result.status_label = status_labels[result.status_code] + + -- Run all the "after" blocks/functions + table.reverse(ancestors) + for _, a in ipairs(ancestors) do + if contexts[a].after then + setfenv(contexts[a].after, env) + contexts[a].after() + end + end + + invoke_callback("after", result) + results[i] = result + end + + return results + +end + +--- Return a detailed report for each context, with the status of each test. +-- @param contexts The contexts returned by load_contexts. +-- @param results The results returned by run. +-- @function test_report +local function test_report(contexts, results) + + local buffer = {} + local leading_space = " " + local level = 0 + local line_char = "-" + local previous_level = 0 + local status_format_len = 3 + local status_format = "[%s]" + local width = 72 + local context_name_format = "%-" .. width - status_format_len .. "s" + local function_name_format = "%-" .. width - status_format_len .. "s" + + local function space() + return leading_space:rep(level - 1) + end + + local function add_divider() + table.insert(buffer, line_char:rep(width)) + end + add_divider() + for i, item in ipairs(contexts) do + local ancestors = ancestors(i, contexts) + previous_level = level or 0 + level = #ancestors + -- the 4 here is the length of "..." 
plus one space of padding + local name = truncate_string(item.name, width - status_format_len - 4 - #ancestors, '...') + if previous_level ~= level and level == 0 then add_divider() end + if item.context then + table.insert(buffer, context_name_format:format(space() .. name .. ':')) + elseif results[i] then + table.insert(buffer, function_name_format:format(space() .. name) .. + status_format:format(results[i].status_label)) + end + end + add_divider() + return table.concat(buffer, "\n") + +end + +--- Return a table of stack traces for tests which produced a failure or an error. +-- @param contexts The contexts returned by load_contexts. +-- @param results The results returned by run. +-- @function error_report +local function error_report(contexts, results) + local buffer = {} + for _, r in filter(results, function(i, r) return r.message end) do + local name = contexts[r.test].name + table.insert(buffer, name .. ":\n" .. r.message[1] .. "\n" .. r.message[2]) + end + if #buffer > 0 then return table.concat(buffer, "\n") end +end + +--- Get a one-line report and a summary table with the status counts. The +-- counts given are: total tests, assertions, passed tests, failed tests, +-- pending tests, and tests which didn't assert anything. +-- @return A report that can be printed +-- @return A table with the various counts. Its fields are: +-- assertions, errors, failed, passed, +-- pending, tests, unassertive. +-- @param contexts The contexts returned by load_contexts. +-- @param results The results returned by run. 
+-- @function summary_report +local function summary_report(contexts, results) + local r = { + assertions = 0, + errors = 0, + failed = 0, + passed = 0, + pending = 0, + tests = 0, + unassertive = 0 + } + for _, v in pairs(results) do + r.tests = r.tests + 1 + r.assertions = r.assertions + v.assertions_invoked + if v.status_code == status_codes.err then r.errors = r.errors + 1 + elseif v.status_code == status_codes.fail then r.failed = r.failed + 1 + elseif v.status_code == status_codes.pass then r.passed = r.passed + 1 + elseif v.status_code == status_codes.pending then r.pending = r.pending + 1 + elseif v.status_code == status_codes.unassertive then r.unassertive = r.unassertive + 1 + end + end + local buffer = {} + for _, k in ipairs({"tests", "passed", "assertions", "failed", "errors", "unassertive", "pending"}) do + local number = r[k] + local label = k + if number == 1 then + label = label:gsub("s$", "") + end + table.insert(buffer, ("%d %s"):format(number, label)) + end + return table.concat(buffer, " "), r +end + +_M.after_aliases = after_aliases +_M.make_assertion = make_assertion +_M.assertion_message_prefix = assertion_message_prefix +_M.before_aliases = before_aliases +_M.context_aliases = context_aliases +_M.error_report = error_report +_M.load_contexts = load_contexts +_M.run = run +_M.test_report = test_report +_M.status_codes = status_codes +_M.status_labels = status_labels +_M.summary_report = summary_report +_M.test_aliases = test_aliases +_M.version = _VERSION +_M._VERSION = _VERSION + +return _M diff --git a/tests/tracetracker.lua b/tests/tracetracker.lua new file mode 100644 index 00000000..68be0380 --- /dev/null +++ b/tests/tracetracker.lua @@ -0,0 +1,224 @@ +local jit = require("jit") +local jutil = require("jit.util") +local vmdef = require("jit.vmdef") +local funcinfo, funcbc, traceinfo = jutil.funcinfo, jutil.funcbc, jutil.traceinfo +local band = bit.band + +local lib +local traces, texits = {}, {} + +local printevents = false + +local 
function fmtfunc(func, pc) + local fi = funcinfo(func, pc) + if fi.loc then + return fi.loc + elseif fi.ffid then + return vmdef.ffnames[fi.ffid] + elseif fi.addr then + return string.format("C:%x", fi.addr) + else + return "(?)" + end +end + +-- Format trace error message. +local function fmterr(err, info) + if type(err) == "number" then + if type(info) == "function" then info = fmtfunc(info) end + err = string.format(vmdef.traceerr[err], info) + end + return err +end + + +local function print_trevent(tr, printall) + + if printall or not tr.stopfunc then + print(string.format("\n[TRACE(%d) start at %s]", tr.traceno, fmtfunc(tr.startfunc, tr.startpc))) + end + + if tr.abort then + print(string.format("[TRACE(%d) abort at %s, error = %s]", tr.traceno, fmtfunc(tr.stopfunc, tr.stoppc), tr.abort)) + elseif tr.stopfunc then + print(string.format("[TRACE(%d) stop at %s]", tr.traceno, fmtfunc(tr.stopfunc, tr.stoppc))) + end +end + +local fwdevents = {} + +local function trace_event(what, tr, func, pc, otr, oex) + + local trace + + if fwdevents.trace then + fwdevents.trace(what, tr, func, pc, otr, oex) + end + + if what == "flush" then + return + end + + if what == "start" then + trace = { + traceno = tr, + startfunc = func, + startpc = pc, + } + traces[#traces+1] = trace + elseif what == "abort" or what == "stop" then + trace = traces[#traces] + assert(trace and trace.traceno == tr) + + trace.stopfunc = func + trace.stoppc = pc + + if what == "abort" then + lib.aborts = lib.aborts+1 + trace.abort = fmterr(otr, oex) + end + else + assert(false, what) + end + + if printevents then + print_trevent(trace) + end +end + +local function trace_exit(tr, ex, ngpr, nfpr, ...) + + if fwdevents.texit then + fwdevents.texit(tr, ex, ngpr, nfpr, ...) 
+ end + + texits[#texits+1] = { + tr = tr, + exitno = ex, + order = #traces + } + + if printevents then + print("---- TRACE ", tr, " exit ", ex) + end +end + +local function isjited(func, starti) + + local hasany = false + + starti = starti or 1 + + for i=starti,#traces do + + local tr = traces[i] + + if not tr.abort then + hasany = true + if tr.startfunc == func or tr.stopfunc == func then + return i, true + end + end + end + + return false,hasany +end + +function hastraces(optfunc, starti) + + if not optfunc then + return traces[1] ~= nil + end + + starti = starti or 1 + + for i=starti,#traces do + + if tr.startfunc == func or tr.stopfunc == func then + return i, true + end + if not tr.abort then + hasany = true + end + end +end + +local active = false + +local function start() + if not active then + active = true + jit.attach(trace_event, "trace") + jit.attach(trace_exit, "texit") + end +end + +local function stop() + if active then + active = false + jit.attach(trace_event) + jit.attach(trace_exit) + --jit.attach(dump_trace) + end +end + +local function set_vmevent_forwarding(fwdtbl) + assert(type(fwdtbl) == "table") + fwdevents = fwdtbl +end + +local function clear() + traces = {} + texits = {} + lib.aborts = 0 +end + +local function clearexits() + texits = {} +end + +local function print_savedevevents() + + local nextexit = texits[1] and texits[1].order + local exi = 1 + + for i,tr in ipairs(traces) do + print_trevent(tr, true) + + if nextexit and i >= nextexit then + for i=exi,#texits do + + local exit = texits[i] + + if exit.order > i then + exi = i + break + end + + print("---- TRACE ", exit.tr, " exit ", exit.exitno) + end + end + + end +end + +lib = { + start = start, + stop = stop, + clear = clear, + clearexits = clearexits, + set_vmevent_forwarding = set_vmevent_forwarding, + + isjited = isjited, + hasexits = function() return texits[1] ~= nil end, + setprintevents = function(enabled) printevents = enabled end, + traces = function() return traces end, 
+ exitcount = function() return #texits end, + traceattemps = function() return #traces end, + print_savedevevents = print_savedevevents, + aborts = 0, +} + +jit.off(true, true) + +return lib \ No newline at end of file diff --git a/tests/tsc b/tests/tsc new file mode 100644 index 00000000..a9ec7531 --- /dev/null +++ b/tests/tsc @@ -0,0 +1,304 @@ +#!/usr/bin/env lua +local telescope = require 'telescope' + +pcall(require, "luarocks.require") +pcall(require, "shake") + +package.path = "./?.lua;" .. package.path + +local function luacov_report() + local luacov = require("luacov.stats") + local data = luacov.load_stats() + if not data then + print("Could not load stats file "..luacov.statsfile..".") + print("Run your Lua program with -lluacov and then rerun luacov.") + os.exit(1) + end + local report = io.open("coverage.html", "w") + report:write('', "\n") + report:write([[ + + + + Luacov Coverage Report + + + + +
+

Luacov Code Coverage Report

+ ]]) + report:write("

Generated on ", os.date(), "

\n") + + local names = {} + for filename, _ in pairs(data) do + table.insert(names, filename) + end + + local escapes = { + [">"] = ">", + ["<"] = "<" + } + local function escape_html(str) + return str:gsub("[<>]", function(a) return escapes[a] end) + end + + table.sort(names) + + for _, filename in ipairs(names) do + if string.match(filename, "/luacov/") or + string.match(filename, "/luarocks/") or + string.match(filename, "/tsc$") + then + break + end + local filedata = data[filename] + filename = string.gsub(filename, "^%./", "") + local file = io.open(filename, "r") + if file then + report:write("

", filename, "

", "\n") + report:write("
") + report:write("
    ", "\n") + local line_nr = 1 + while true do + local line = file:read("*l") + if not line then break end + if line:match("^%s*%-%-") then -- Comment line + + elseif line:match("^%s*$") -- Empty line + or line:match("^%s*end,?%s*$") -- Single "end" + or line:match("^%s*else%s*$") -- Single "else" + or line:match("^%s*{%s*$") -- Single opening brace + or line:match("^%s*}%s*$") -- Single closing brace + or line:match("^#!") -- Unix hash-bang magic line + then + report:write("
  • ", string.format("%-4d", line_nr), "      ", escape_html(line), "
  • ", "\n") + else + local hits = filedata[line_nr] + local class = "uncovered" + if not hits then hits = 0 end + if hits > 0 then class = "covered" end + report:write("
  • ", "
    ", string.format("%-4d", line_nr), string.format("%-4d", hits), " ", escape_html(line), "
  • ", "\n") + end + line_nr = line_nr + 1 + end + end + report:write("
", "\n") + report:write("
", "\n") + end + report:write([[ +
+ + + ]]) +end + +local function getopt(arg, options) + local tab = {} + for k, v in ipairs(arg) do + if string.sub(v, 1, 2) == "--" then + local x = string.find(v, "=", 1, true) + if x then tab[string.sub(v, 3, x - 1)] = string.sub(v, x + 1) + else tab[string.sub(v, 3)] = true + end + elseif string.sub(v, 1, 1) == "-" then + local y = 2 + local l = string.len(v) + local jopt + while (y <= l) do + jopt = string.sub(v, y, y) + if string.find(options, jopt, 1, true) then + if y < l then + tab[jopt] = string.sub(v, y + 1) + y = l + else + tab[jopt] = arg[k + 1] + end + else + tab[jopt] = true + end + y = y + 1 + end + end + end + return tab +end + +local callbacks = {} + +local function progress_meter(t) + io.stdout:write(t.status_label) +end + +local function show_usage() + local text = [[ +Telescope + +Usage: tsc [options] [files] + +Description: + Telescope is a test framework for Lua that allows you to write tests + and specs in a TDD or BDD style. + +Options: + + -f, --full Show full report + -q, --quiet Show don't show any stack traces + -s --silent Don't show any output + -h,-? 
--help Show this text + -v --version Show version + -c --luacov Output a coverage file using Luacov (http://luacov.luaforge.net/) + --load= Load a Lua file before executing command + --name= Only run tests whose name matches a Lua string pattern + --shake Use shake as the front-end for tests + + Callback options: + --after= Run function given after each test + --before= Run function before each test + --err= Run function after each test that produces an error + --fail Run function after each failing test + --pass= Run function after each passing test + --pending= Run function after each pending test + --unassertive= Run function after each unassertive test + + An example callback: + + tsc --after="function(t) print(t.status_label, t.name, t.context) end" example.lua + +An example test: + +context("A context", function() + before(function() end) + after(function() end) + context("A nested context", function() + test("A test", function() + assert_not_equal("ham", "cheese") + end) + context("Another nested context", function() + test("Another test", function() + assert_greater_than(2, 1) + end) + end) + end) + test("A test in the top-level context", function() + assert_equal(1, 1) + end) +end) + +Project home: + http://telescope.luaforge.net/ + +License: + MIT/X11 (Same as Lua) + +Author: + Norman Clarke . Please feel free to email bug + reports, feedback and feature requests. 
+]] + print(text) +end + +local function add_callback(callback, func) + if callbacks[callback] then + if type(callbacks[callback]) ~= "table" then + callbacks[callback] = {callbacks[callback]} + end + table.insert(callbacks[callback], func) + else + callbacks[callback] = func + end +end + +local function process_args() + local files = {} + local opts = getopt(arg, "") + local i = 1 + for _, _ in pairs(opts) do i = i+1 end + while i <= #arg do table.insert(files, arg[i]) ; i = i + 1 end + return opts, files +end +local opts, files = process_args() +if opts["h"] or opts["?"] or opts["help"] or not (next(opts) or next(files)) then + show_usage() + os.exit() +end + +if opts.v or opts.version then + print(telescope.version) + os.exit(0) +end + +if opts.c or opts.luacov then + require "luacov.tick" +end + +-- load a file with custom functionality if desired +if opts["load"] then dofile(opts["load"]) end + +local test_pattern +if opts["name"] then + test_pattern = function(t) return t.name:match(opts["name"]) end +end + +-- set callbacks passed on command line +local callback_args = { "after", "before", "err", "fail", "pass", + "pending", "unassertive" } +for _, callback in ipairs(callback_args) do + if opts[callback] then + add_callback(callback, loadstring(opts[callback])()) + end +end + +local contexts = {} +if opts["shake"] then + for _, file in ipairs(files) do shake.load_contexts(file, contexts) end +else + for _, file in ipairs(files) do telescope.load_contexts(file, contexts) end +end + +local buffer = {} +local results = telescope.run(contexts, callbacks, test_pattern) +local summary, data = telescope.summary_report(contexts, results) + +if opts.f or opts.full then + table.insert(buffer, telescope.test_report(contexts, results)) +end + +if not opts.s and not opts.silent then + table.insert(buffer, summary) + if not opts.q and not opts.quiet then + local report = telescope.error_report(contexts, results) + if report then + table.insert(buffer, "") + 
table.insert(buffer, report) + end + end +end + +if #buffer > 0 then print(table.concat(buffer, "\n")) end + +if opts.c or opts.coverage then + luacov_report() + os.remove("luacov.stats.out") +end + +for _, v in pairs(results) do + if v.status_code == telescope.status_codes.err or + v.status_code == telescope.status_codes.fail then + os.exit(1) + end +end