diff --git a/src/lj_asm.c b/src/lj_asm.c index 93f2eb2d..c01bf98e 100644 --- a/src/lj_asm.c +++ b/src/lj_asm.c @@ -2701,7 +2701,7 @@ static void wrap_intrins(jit_State *J, CIntrinsic *intrins, IntrinWrapState *sta uint8_t *in = info.in, *out = info.out; int spadj = 0; int dynreg = intrin_regmode(intrins); - Reg rout = RID_NONE, rin = RID_NONE; + Reg rout = RID_NONE, rin = RID_NONE, r3 = RID_NONE; lj_asm_setup_intrins(J, as); origtop = as->mctop; @@ -2735,6 +2735,9 @@ static void wrap_intrins(jit_State *J, CIntrinsic *intrins, IntrinWrapState *sta info.inset |= pickdynlist(in+inofs, intrins->dyninsz-inofs, scatch); } + if (dynreg == DYNREG_VEX3) + r3 = reg_rid(in[1]); + if (rin == RID_NONE) rin = reg_rid(in[0]); @@ -2746,6 +2749,9 @@ static void wrap_intrins(jit_State *J, CIntrinsic *intrins, IntrinWrapState *sta } else if (dynreg == DYNREG_OPEXT) { /* Destructive single register opcode */ rout = out[0] = reg_setrid(out[0], rin); + /* Becomes non destructive in vex form*/ + if (intrins->flags & INTRINSFLAG_VEX) + r3 = reg_rid(rout); } else { scatch = RSET_INIT & ~info.outset; rset_clear(scatch, info.outcontext); @@ -2868,10 +2874,10 @@ restart: if (intrins->flags & INTRINSFLAG_CALLED) { /* emit a call to the target which may be collocated after us */ - emit_intrins(as, intrins, rin, (uintptr_t)target); + emit_intrins(as, intrins, rin, (uintptr_t)target, 0); } else if (dynreg) { /* Write an opcode to the wrapper */ - asmofs = emit_intrins(as, intrins, rin, rout); + asmofs = emit_intrins(as, intrins, rin, rout, r3); } else { /* Append the user supplied machine code */ asmofs = asm_mcode(as, state->target, state->targetsz); diff --git a/src/lj_asm_x86.h b/src/lj_asm_x86.h index b7ce4b06..de637c29 100644 --- a/src/lj_asm_x86.h +++ b/src/lj_asm_x86.h @@ -730,8 +730,10 @@ static void asm_intrin_opcode(ASMState *as, IRIns *ir, IntrinsInfo *ininfo) uint32_t dynreg = intrin_regmode(intrins); RegSet allow; IRRef lref = 0, rref = 0; - Reg right, dest = RID_NONE; + Reg right, 
dest = RID_NONE, vvvv = RID_NONE; int dynrout = intrins->outsz > 0 && intrin_dynrout(intrins); + int vex3 = dynreg == DYNREG_VEX3; + int vexop = intrins->flags & INTRINSFLAG_VEX; /* Swap to refs to native ordering */ if (dynreg >= DYNREG_SWAPREGS) { @@ -762,8 +764,12 @@ static void asm_intrin_opcode(ASMState *as, IRIns *ir, IntrinsInfo *ininfo) } dest = ra_dest(as, ir, allow); if (dynreg == DYNREG_OPEXT) { - /* Set input register the same as the output since the op is destructive */ - right = dest; + if (vexop) { + vvvv = dest; + } else { + /* Set input register the same as the output since the op is destructive */ + right = dest; + } } } @@ -822,7 +828,8 @@ static void asm_intrin_opcode(ASMState *as, IRIns *ir, IntrinsInfo *ininfo) /* Handle second input reg for any two input dynamic in register modes ** which isn't DYNREG_INOUT */ - if (intrins->dyninsz > 1 && ra_noreg(dest)) { + if (intrins->dyninsz > 1 && ((!vex3 && ra_noreg(dest)) || + (vex3 && ra_noreg(IR(args[1])->r)))) { Reg r; allow = reg_torset(in[1]) & ~ininfo->inset; if (ra_hasreg(right) && right != RID_MRM) @@ -830,7 +837,13 @@ static void asm_intrin_opcode(ASMState *as, IRIns *ir, IntrinsInfo *ininfo) r = ra_allocref(as, args[1], allow); in[1] = reg_setrid(in[1], r); - dest = r; + + if (!vex3) { + dest = r; + } else if (lref == rref) { + /* update right for same ref */ + right = r; + } } if (right == RID_MRM) { @@ -848,15 +861,19 @@ static void asm_intrin_opcode(ASMState *as, IRIns *ir, IntrinsInfo *ininfo) in[0] = reg_setrid(in[0], right); } + if (vexop && vex3) { + vvvv = reg_rid(in[1]); + } + lua_assert(ra_hasreg(right) && (ra_hasreg(dest) || intrins->dyninsz < 2)); - emit_intrins(as, intrins, right, dest); + emit_intrins(as, intrins, right, dest, vvvv); if (dynreg == DYNREG_INOUT) { lua_assert(lref); ra_left(as, dest, lref); /* no need to load the register since ra_left already did */ in[1] = 0xff; - } else if (dynreg == DYNREG_OPEXT && dynrout) { + } else if (dynreg == DYNREG_OPEXT && dynrout && 
!vexop) { /* Handle destructive ONEOPEXT opcodes */ lua_assert(rref); ra_left(as, dest, rref); @@ -1011,7 +1028,7 @@ static void asm_intrinsic(ASMState *as, IRIns *ir, IRIns *asmend) r1 = ra_scratch(as, RSET_GPR & ~(ininfo.inset | ininfo.outset)); } } - emit_intrins(as, intrins, r1, target); + emit_intrins(as, intrins, r1, target, 0); } asm_asmsetupargs(as, &ininfo); diff --git a/src/lj_emit_x86.h b/src/lj_emit_x86.h index 56342858..b03b0796 100644 --- a/src/lj_emit_x86.h +++ b/src/lj_emit_x86.h @@ -33,6 +33,9 @@ /* msb is also set to c5 so we can spot a vex op in op_emit */ #define VEX2 0xc5c5 +#define VEX_OP2(o, pp) ((uint32_t)(0xf8c5c5 | ((pp<<16) + (o<<24)))) +#define VEX_OP3(o, pp, mode) ((uint32_t)(0x78e0c4 | (mode << 8) | ((pp<<16) + (o<<24)))) + /* vvvv bits in the opcode are assumed to be set */ #define VEXOP_SETVVVV(o, rid) ((o) ^ (((rid < RID_MIN_FPR ? \ rid : (rid)-RID_MIN_FPR)) << 19)) @@ -40,6 +43,49 @@ /* extract and merge the opcode,vvv,L,pp, W and set VEXMAP_0F */ #define VEX2TO3(op) ((op & 0xff7f0000) | 0xe1c4 | ((op & 0x800000) >> 8)) +static int vexpp(uint32_t byte) +{ + switch (byte) { + case 0x66: + return VEXPP_66; + case 0xf3: + return VEXPP_f3; + case 0xf2: + return VEXPP_f2; + default: + return VEXPP_0f; + } +} + +static int vexmap(uint32_t byte) +{ + switch (byte & 0xffff) { + case 0x380F: + return VEXMAP_0F38; + case 0x3a0f: + return VEXMAP_0F3A; + default: + lua_assert((byte & 0xff) == 0x0f); + return VEXMAP_0F; + } +} + +uint32_t sse2vex(uint32_t op, uint32_t len, uint32_t vex_w) +{ + x86Op vo = op >> 24; + int32_t pp = vexpp((op >> ((4-len) * 8)) & 0xff); + uint32_t mode = vexmap(op >> (len == 4 ? 
8 : 16)); + + if (!vex_w && (len == 2 || (len == 3 && mode == VEXMAP_0F))) { + vo = VEX_OP2(vo, pp); + } else { + vo = VEX_OP3(vo, pp, mode); + if(vex_w) + vo |= VEX_64; + } + return vo; +} + #define emit_i8(as, i) (*--as->mcp = (MCode)(i)) #define emit_i32(as, i) (*(int32_t *)(as->mcp-4) = (i), as->mcp -= 4) #define emit_u32(as, u) (*(uint32_t *)(as->mcp-4) = (u), as->mcp -= 4) @@ -653,7 +699,7 @@ static void emit_addptr(ASMState *as, Reg r, int32_t ofs) static MCode* emit_intrins(ASMState *as, CIntrinsic *intrins, Reg r1, - uintptr_t r2) + uintptr_t r2, Reg r3) { uint32_t regmode = intrin_regmode(intrins); if (regmode) { @@ -674,7 +720,15 @@ static MCode* emit_intrins(ASMState *as, CIntrinsic *intrins, Reg r1, r2 |= OP4B; } - emit_mrm(as, intrins->opcode, (Reg)r2, r1); + if (intrins->flags & INTRINSFLAG_VEX) { + x86Op op = intrins->opcode; + if (r3 != RID_NONE) { + op = VEXOP_SETVVVV(op, r3); + } + emit_mrm(as, op, (Reg)r2, r1); + } else { + emit_mrm(as, intrins->opcode, (Reg)r2, r1); + } if (intrins->flags & INTRINSFLAG_PREFIX) { *--as->mcp = intrins->prefix; diff --git a/src/lj_intrinsic.c b/src/lj_intrinsic.c index c140a708..c5ee7a94 100644 --- a/src/lj_intrinsic.c +++ b/src/lj_intrinsic.c @@ -333,6 +333,12 @@ static int parse_opmode(const char *op, MSize len) case 'E': flags |= INTRINSFLAG_EXPLICTREGS; break; + case 'V': + flags |= INTRINSFLAG_AVXREQ; + case 'v': + /* Use vex encoding of the op if avx/xv2 is supported */ + flags |= INTRINSFLAG_VEX; + break; default: /* return index of invalid flag */ @@ -588,7 +594,12 @@ GCcdata *lj_intrinsic_createffi(CTState *cts, CType *func) RegSet mod = intrin_getmodrset(cts, intrins); if (intrins->opcode == 0) { - lj_err_callermsg(cts->L, "expected non template intrinsic"); + if (intrin_regmode(intrins) == DYNREG_FIXED) { + lj_err_callermsg(cts->L, "expected non template intrinsic"); + } else { + /* Opcode gets set to 0 during parsing if the cpu feature missing */ + lj_err_callermsg(cts->L, "Intrinsic not support by 
cpu");
+    }
   }
 
   /* Build the interpreter wrapper */
@@ -606,6 +617,8 @@ GCcdata *lj_intrinsic_createffi(CTState *cts, CType *func)
   return cd;
 }
 
+extern uint32_t sse2vex(uint32_t op, uint32_t len, uint32_t vex_w);
+
 int lj_intrinsic_fromcdef(lua_State *L, CTypeID fid, GCstr *opstr, uint32_t imm)
 {
   CTState *cts = ctype_cts(L);
@@ -735,6 +748,39 @@ int lj_intrinsic_fromcdef(lua_State *L, CTypeID fid, GCstr *opstr, uint32_t imm)
     uint8_t temp = intrins->in[0];
     intrins->in[0] = intrins->in[1]; intrins->in[1] = temp;
   }
+
+  if (intrins->flags & INTRINSFLAG_VEX) {
+    int vex_w = 0;
+
+    /* Set the VEX.W/E bit if the X flag is set */
+    if (intrins->flags & INTRINSFLAG_REXW) {
+      vex_w = 1;
+      intrins->flags &= ~INTRINSFLAG_REXW;
+    }
+
+    if (L2J(L)->flags & JIT_F_AVX1) {
+      intrins->opcode = sse2vex(intrins->opcode, intrin_oplen(intrins), vex_w);
+      intrins->flags &= ~INTRINSFLAG_LARGEOP;
+      /* Switch to non destructive source if the sse reg mode is destructive */
+      if (intrin_regmode(intrins) == DYNREG_INOUT) {
+        intrin_setregmode(intrins, DYNREG_VEX3);
+      }
+      /* Set the VEX.L bit if the opcode has any 256 bit registers declared */
+      if (intrins->flags & INTRINSFLAG_VEX256) {
+        intrins->opcode |= VEX_256;
+      }
+    } else if(buildflags & INTRINSFLAG_AVXREQ) {
+      /* Disable instantiation of the intrinsic since AVX is not supported by
+      ** the CPU. NOTE(review): should this test intrins->flags instead of
+      ** buildflags? AVXREQ was set by parse_opmode -- verify. */
+      intrins->opcode = 0;
+    } else {
+      /* Vex encoding is not optional with these flags */
+      if ((intrins->flags & INTRINSFLAG_VEX256) || vex_w) {
+        return 0;
+      }
+      /* Use opcode unmodified in its SSE form */
+      intrins->flags &= ~INTRINSFLAG_VEX;
+    }
+  }
 #endif
 
   if (intrins->flags & INTRINSFLAG_PREFIX) {
diff --git a/src/lj_intrinsic.h b/src/lj_intrinsic.h
index 438174c2..ae095718 100644
--- a/src/lj_intrinsic.h
+++ b/src/lj_intrinsic.h
@@ -34,6 +34,8 @@ typedef enum REGMODE {
   DYNREG_OPEXT,
   /* Two input register and one output same register that's same RID the second input */
   DYNREG_INOUT,
+  /* Two inputs, one output (non-destructive VEX three-operand form) */
+  DYNREG_VEX3,
   /* Two input registers with M dynamic 
output register */ DYNREG_TWOIN, @@ -44,7 +46,8 @@ typedef enum INTRINSFLAGS { INTRINSFLAG_REGMODEMASK = 7, INTRINSFLAG_MEMORYSIDE = 0x08, /* has memory side effects so needs an IR memory barrier */ - + /* Vex encoded opcode, vvvv may be unused though */ + INTRINSFLAG_VEX = 0x10, /* Intrinsic should be emitted as a naked function that is called */ INTRINSFLAG_CALLED = 0x20, /* MODRM should always be set as indirect mode */ @@ -72,6 +75,8 @@ typedef enum INTRINSFLAGS { ** user supplied code. */ INTRINSFLAG_TEMPLATE = 0x40000, + /* Opcode is only supported if the CPU supports AVX */ + INTRINSFLAG_AVXREQ = 0x80000, INTRINSFLAG_CALLEDIND = INTRINSFLAG_CALLED | INTRINSFLAG_INDIRECT } INTRINSFLAGS; diff --git a/src/lj_target_x86.h b/src/lj_target_x86.h index ffcd7411..906bbb80 100644 --- a/src/lj_target_x86.h +++ b/src/lj_target_x86.h @@ -236,6 +236,8 @@ typedef enum VEXMAP { VEXMAP_0F3A = 3, } VEXMAP; +#define VEX_256 0x40000 + /* This list of x86 opcodes is not intended to be complete. Opcodes are only ** included when needed. Take a look at DynASM or jit.dis_x86 to see the ** whole mess. 
diff --git a/tests/intrinsic_spec.lua b/tests/intrinsic_spec.lua index 80f39bf5..bf621444 100644 --- a/tests/intrinsic_spec.lua +++ b/tests/intrinsic_spec.lua @@ -6,6 +6,7 @@ typedef float float4 __attribute__((__vector_size__(16))); typedef float float8 __attribute__((__vector_size__(32))); typedef int int4 __attribute__((__vector_size__(16))); typedef uint8_t byte16 __attribute__((__vector_size__(16))); +typedef int64_t long2 __attribute__((__vector_size__(16))); ]] local float4 = ffi.new("float[4]") @@ -862,7 +863,7 @@ it("popcnt", function() end) it("addsd", function() - assert_cdef([[double addsd(double n1, double n2) __mcode("F20F58rM");]], "addsd") + assert_cdef([[double addsd(double n1, double n2) __mcode("F20F58rMv");]], "addsd") local addsd = ffi.C.addsd function test_addsd(n1, n2) @@ -890,7 +891,7 @@ it("addsd", function() assert_equal(6, test_addsd2(3)) --check unfused - ffi.cdef([[double addsduf(double n1, double n2) __mcode("F20F58rR");]]) + ffi.cdef([[double addsduf(double n1, double n2) __mcode("F20F58rRv");]]) addsd = ffi.C.addsduf assert_equal(3, addsd(1, 2)) @@ -898,7 +899,7 @@ it("addsd", function() end) it("addss", function() - assert_cdef([[float addss(float n1, float n2) __mcode("F30F58rM");]], "addss") + assert_cdef([[float addss(float n1, float n2) __mcode("F30F58rMv");]], "addss") local addsd = ffi.C.addss function test_addsd(n1, n2) @@ -922,7 +923,7 @@ it("addss", function() assert_noexit(3, test_addss2, 1.5) --check unfused - ffi.cdef[[float addssuf(float n1, float n2) __mcode("F30F58rR");]] + ffi.cdef[[float addssuf(float n1, float n2) __mcode("F30F58rRv");]] addsd = ffi.C.addssuf assert_equal(3, addsd(1, 2)) @@ -957,6 +958,48 @@ it("shufps", function() assert_equal(vout[3], 1.5) end) +it("vpermilps(avx)", function() + assert_cdef([[float4 vpermilps(float4 v1, int4 control) __mcode("660F380CrMV");]], "vpermilps") + + local v1, v2 = ffi.new("float4", 1, 2, 3, 4) + v2 = ffi.new("int4", 0, 0, 0, 0) + assert_v4eq(ffi.C.vpermilps(v1, v2), 1, 
1, 1, 1)
+
+  -- Reverse the vector
+  v2 = ffi.new("int4", 3, 2, 1, 0)
+  assert_v4eq(ffi.C.vpermilps(v1, v2), 4, 3, 2, 1)
+end)
+
+it("vpslldq(avx)", function()
+  assert_cdef([[int4 vpslldq_4(int4 v1) __mcode("660F737mUV", 4);]], "vpslldq_4")
+
+  local v = ffi.new("int4", 1, 2, 3, 4)
+  assert_v4eq(ffi.C.vpslldq_4(v), 0, 1, 2, 3)
+end)
+
+it("vaddps(avx, ymm)", function()
+  assert_cdef([[float8 vaddps(float8 v1, float8 v2) __mcode("0F58rMV");]], "vaddps")
+
+  local v1 = ffi.new("float8", 1, 2, 3, 4, 5, 6, 7, 8)
+  local v2 = ffi.new("float8", 1, 1, 1, 1, 1, 1, 1, 1)
+
+  local vout = ffi.C.vaddps(v1, v2)
+
+  for i=0,7 do
+    assert_equal(vout[i], i+2)
+  end
+end)
+
+if ffi.arch == "x64" then
+  --Check the VEX.W (REX.W) bit is correctly set when the X opcode flag is specified on vex opcodes
+  it("vmovd VEX.L bit(avx)", function()
+    assert_cdef([[long2 vmovd(int64_t n) __mcode("660F6ErMVX");]], "vmovd")
+
+    local v = ffi.C.vmovd(-1LL)
+    assert_equal(v[0], -1LL)
+  end)
+end
+
 it("phaddd 4byte opcode", function()
   ffi.cdef([[int4 phaddd(int4 v1, int4 v2) __mcode("660F3802rM");]])