Added VEX opcode support for intrinsics

This commit is contained in:
fsfod 2016-03-29 11:59:16 +01:00
parent 239f8ad3e6
commit befcdc6e55
7 changed files with 192 additions and 19 deletions

View File

@ -2701,7 +2701,7 @@ static void wrap_intrins(jit_State *J, CIntrinsic *intrins, IntrinWrapState *sta
uint8_t *in = info.in, *out = info.out;
int spadj = 0;
int dynreg = intrin_regmode(intrins);
Reg rout = RID_NONE, rin = RID_NONE;
Reg rout = RID_NONE, rin = RID_NONE, r3 = RID_NONE;
lj_asm_setup_intrins(J, as);
origtop = as->mctop;
@ -2735,6 +2735,9 @@ static void wrap_intrins(jit_State *J, CIntrinsic *intrins, IntrinWrapState *sta
info.inset |= pickdynlist(in+inofs, intrins->dyninsz-inofs, scatch);
}
if (dynreg == DYNREG_VEX3)
r3 = reg_rid(in[1]);
if (rin == RID_NONE)
rin = reg_rid(in[0]);
@ -2746,6 +2749,9 @@ static void wrap_intrins(jit_State *J, CIntrinsic *intrins, IntrinWrapState *sta
} else if (dynreg == DYNREG_OPEXT) {
/* Destructive single register opcode */
rout = out[0] = reg_setrid(out[0], rin);
/* Becomes non-destructive in VEX form */
if (intrins->flags & INTRINSFLAG_VEX)
r3 = reg_rid(rout);
} else {
scatch = RSET_INIT & ~info.outset;
rset_clear(scatch, info.outcontext);
@ -2868,10 +2874,10 @@ restart:
if (intrins->flags & INTRINSFLAG_CALLED) {
/* emit a call to the target which may be collocated after us */
emit_intrins(as, intrins, rin, (uintptr_t)target);
emit_intrins(as, intrins, rin, (uintptr_t)target, 0);
} else if (dynreg) {
/* Write an opcode to the wrapper */
asmofs = emit_intrins(as, intrins, rin, rout);
asmofs = emit_intrins(as, intrins, rin, rout, r3);
} else {
/* Append the user supplied machine code */
asmofs = asm_mcode(as, state->target, state->targetsz);

View File

@ -730,8 +730,10 @@ static void asm_intrin_opcode(ASMState *as, IRIns *ir, IntrinsInfo *ininfo)
uint32_t dynreg = intrin_regmode(intrins);
RegSet allow;
IRRef lref = 0, rref = 0;
Reg right, dest = RID_NONE;
Reg right, dest = RID_NONE, vvvv = RID_NONE;
int dynrout = intrins->outsz > 0 && intrin_dynrout(intrins);
int vex3 = dynreg == DYNREG_VEX3;
int vexop = intrins->flags & INTRINSFLAG_VEX;
/* Swap to refs to native ordering */
if (dynreg >= DYNREG_SWAPREGS) {
@ -762,10 +764,14 @@ static void asm_intrin_opcode(ASMState *as, IRIns *ir, IntrinsInfo *ininfo)
}
dest = ra_dest(as, ir, allow);
if (dynreg == DYNREG_OPEXT) {
if (vexop) {
vvvv = dest;
} else {
/* Set input register the same as the output since the op is destructive */
right = dest;
}
}
}
if (intrins->dyninsz > 1 && dynreg != DYNREG_TWOSTORE) {
if (lref == rref) {
@ -822,7 +828,8 @@ static void asm_intrin_opcode(ASMState *as, IRIns *ir, IntrinsInfo *ininfo)
/* Handle second input reg for any two input dynamic in register modes
** which isn't DYNREG_INOUT
*/
if (intrins->dyninsz > 1 && ra_noreg(dest)) {
if (intrins->dyninsz > 1 && ((!vex3 && ra_noreg(dest)) ||
(vex3 && ra_noreg(IR(args[1])->r)))) {
Reg r;
allow = reg_torset(in[1]) & ~ininfo->inset;
if (ra_hasreg(right) && right != RID_MRM)
@ -830,7 +837,13 @@ static void asm_intrin_opcode(ASMState *as, IRIns *ir, IntrinsInfo *ininfo)
r = ra_allocref(as, args[1], allow);
in[1] = reg_setrid(in[1], r);
if (!vex3) {
dest = r;
} else if (lref == rref) {
/* update right for same ref */
right = r;
}
}
if (right == RID_MRM) {
@ -848,15 +861,19 @@ static void asm_intrin_opcode(ASMState *as, IRIns *ir, IntrinsInfo *ininfo)
in[0] = reg_setrid(in[0], right);
}
if (vexop && vex3) {
vvvv = reg_rid(in[1]);
}
lua_assert(ra_hasreg(right) && (ra_hasreg(dest) || intrins->dyninsz < 2));
emit_intrins(as, intrins, right, dest);
emit_intrins(as, intrins, right, dest, vvvv);
if (dynreg == DYNREG_INOUT) {
lua_assert(lref);
ra_left(as, dest, lref);
/* no need to load the register since ra_left already did */
in[1] = 0xff;
} else if (dynreg == DYNREG_OPEXT && dynrout) {
} else if (dynreg == DYNREG_OPEXT && dynrout && !vexop) {
/* Handle destructive ONEOPEXT opcodes */
lua_assert(rref);
ra_left(as, dest, rref);
@ -1011,7 +1028,7 @@ static void asm_intrinsic(ASMState *as, IRIns *ir, IRIns *asmend)
r1 = ra_scratch(as, RSET_GPR & ~(ininfo.inset | ininfo.outset));
}
}
emit_intrins(as, intrins, r1, target);
emit_intrins(as, intrins, r1, target, 0);
}
asm_asmsetupargs(as, &ininfo);

View File

@ -33,6 +33,9 @@
/* msb is also set to c5 so we can spot a vex op in op_emit */
#define VEX2 0xc5c5
#define VEX_OP2(o, pp) ((uint32_t)(0xf8c5c5 | ((pp<<16) + (o<<24))))
#define VEX_OP3(o, pp, mode) ((uint32_t)(0x78e0c4 | (mode << 8) | ((pp<<16) + (o<<24))))
/* vvvv bits in the opcode are assumed to be set */
#define VEXOP_SETVVVV(o, rid) ((o) ^ (((rid < RID_MIN_FPR ? \
rid : (rid)-RID_MIN_FPR)) << 19))
@ -40,6 +43,49 @@
/* extract and merge the opcode,vvv,L,pp, W and set VEXMAP_0F */
#define VEX2TO3(op) ((op & 0xff7f0000) | 0xe1c4 | ((op & 0x800000) >> 8))
/* Map an SSE mandatory-prefix byte to its VEX.pp field encoding.
** 0x66/0xF3/0xF2 select the corresponding VEXPP constant; any other
** value (i.e. no mandatory prefix) selects VEXPP_0f.
*/
static int vexpp(uint32_t byte)
{
  if (byte == 0x66)
    return VEXPP_66;
  if (byte == 0xf3)
    return VEXPP_f3;
  if (byte == 0xf2)
    return VEXPP_f2;
  return VEXPP_0f;
}
/* Map the opcode-map escape bytes of an SSE opcode to the VEX opcode-map
** selector: 0F38 -> VEXMAP_0F38, 0F3A -> VEXMAP_0F3A, otherwise plain 0F.
** NOTE(review): the two-byte escapes are matched in byte-swapped form
** (0x380F / 0x3a0f) — confirm this matches the opcode packing used here.
*/
static int vexmap(uint32_t byte)
{
  uint32_t escape = byte & 0xffff;
  if (escape == 0x380F)
    return VEXMAP_0F38;
  if (escape == 0x3a0f)
    return VEXMAP_0F3A;
  /* Anything else must be a bare 0F map. */
  lua_assert((byte & 0xff) == 0x0f);
  return VEXMAP_0F;
}
/* Convert a packed SSE opcode into its equivalent VEX-encoded opcode.
** op:    packed opcode bytes (primary opcode in the top byte).
** len:   total opcode length in bytes (2-4); selects where the mandatory
**        prefix and escape bytes sit inside 'op'.
** vex_w: non-zero to set the VEX.W bit (forces the 3-byte VEX form).
** Returns the packed VEX opcode built via VEX_OP2/VEX_OP3.
*/
uint32_t sse2vex(uint32_t op, uint32_t len, uint32_t vex_w)
{
  x86Op vo = op >> 24;  /* Primary opcode byte. */
  /* The mandatory prefix (0x66/0xF2/0xF3, if any) is the lowest opcode byte;
  ** its position in 'op' depends on the total opcode length.
  */
  int32_t pp = vexpp((op >> ((4-len) * 8)) & 0xff);
  /* Escape bytes (0F / 0F38 / 0F3A) sit above the prefix byte. */
  uint32_t mode = vexmap(op >> (len == 4 ? 8 : 16));
  /* Use the compact 2-byte VEX form only when W is clear and the opcode is
  ** in the 0F map (which is all a 2-byte VEX prefix can express).
  */
  if (!vex_w && (len == 2 || (len == 3 && mode == VEXMAP_0F))) {
    vo = VEX_OP2(vo, pp);
  } else {
    vo = VEX_OP3(vo, pp, mode);
    if(vex_w)
      vo |= VEX_64;  /* Set VEX.W in the 3-byte form. */
  }
  return vo;
}
#define emit_i8(as, i) (*--as->mcp = (MCode)(i))
#define emit_i32(as, i) (*(int32_t *)(as->mcp-4) = (i), as->mcp -= 4)
#define emit_u32(as, u) (*(uint32_t *)(as->mcp-4) = (u), as->mcp -= 4)
@ -653,7 +699,7 @@ static void emit_addptr(ASMState *as, Reg r, int32_t ofs)
static MCode* emit_intrins(ASMState *as, CIntrinsic *intrins, Reg r1,
uintptr_t r2)
uintptr_t r2, Reg r3)
{
uint32_t regmode = intrin_regmode(intrins);
if (regmode) {
@ -674,7 +720,15 @@ static MCode* emit_intrins(ASMState *as, CIntrinsic *intrins, Reg r1,
r2 |= OP4B;
}
if (intrins->flags & INTRINSFLAG_VEX) {
x86Op op = intrins->opcode;
if (r3 != RID_NONE) {
op = VEXOP_SETVVVV(op, r3);
}
emit_mrm(as, op, (Reg)r2, r1);
} else {
emit_mrm(as, intrins->opcode, (Reg)r2, r1);
}
if (intrins->flags & INTRINSFLAG_PREFIX) {
*--as->mcp = intrins->prefix;

View File

@ -333,6 +333,12 @@ static int parse_opmode(const char *op, MSize len)
case 'E':
flags |= INTRINSFLAG_EXPLICTREGS;
break;
case 'V':
flags |= INTRINSFLAG_AVXREQ;
case 'v':
/* Use the VEX encoding of the op if AVX/AVX2 is supported */
flags |= INTRINSFLAG_VEX;
break;
default:
/* return index of invalid flag */
@ -588,7 +594,12 @@ GCcdata *lj_intrinsic_createffi(CTState *cts, CType *func)
RegSet mod = intrin_getmodrset(cts, intrins);
if (intrins->opcode == 0) {
if (intrin_regmode(intrins) == DYNREG_FIXED) {
lj_err_callermsg(cts->L, "expected non template intrinsic");
} else {
/* Opcode gets set to 0 during parsing if the CPU feature is missing */
lj_err_callermsg(cts->L, "Intrinsic not support by cpu");
}
}
/* Build the interpreter wrapper */
@ -606,6 +617,8 @@ GCcdata *lj_intrinsic_createffi(CTState *cts, CType *func)
return cd;
}
extern uint32_t sse2vex(uint32_t op, uint32_t len, uint32_t vex_w);
int lj_intrinsic_fromcdef(lua_State *L, CTypeID fid, GCstr *opstr, uint32_t imm)
{
CTState *cts = ctype_cts(L);
@ -735,6 +748,39 @@ int lj_intrinsic_fromcdef(lua_State *L, CTypeID fid, GCstr *opstr, uint32_t imm)
uint8_t temp = intrins->in[0];
intrins->in[0] = intrins->in[1]; intrins->in[1] = temp;
}
if (intrins->flags & INTRINSFLAG_VEX) {
int vex_w = 0;
/* Set the VEX.W/E bit if the X flag is set */
if (intrins->flags & INTRINSFLAG_REXW) {
vex_w = 1;
intrins->flags &= ~INTRINSFLAG_REXW;
}
if (L2J(L)->flags & JIT_F_AVX1) {
intrins->opcode = sse2vex(intrins->opcode, intrin_oplen(intrins), vex_w);
intrins->flags &= ~INTRINSFLAG_LARGEOP;
/* Switch to non destructive source if the sse reg mode is destructive */
if (intrin_regmode(intrins) == DYNREG_INOUT) {
intrin_setregmode(intrins, DYNREG_VEX3);
}
/* Set the VEX.L bit if the opcode has any 256 bit registers declared */
if (intrins->flags & INTRINSFLAG_VEX256) {
intrins->opcode |= VEX_256;
}
} else if(buildflags & INTRINSFLAG_AVXREQ) {
/* Disable instantiation of the intrinsic since AVX is not supported by the CPU */
intrins->opcode = 0;
} else {
/* Vex encoding is not optional with these flags */
if ((intrins->flags & INTRINSFLAG_VEX256) || vex_w) {
return 0;
}
/* Use opcode unmodified in its SSE form */
intrins->flags &= ~INTRINSFLAG_VEX;
}
}
#endif
if (intrins->flags & INTRINSFLAG_PREFIX) {

View File

@ -34,6 +34,8 @@ typedef enum REGMODE {
DYNREG_OPEXT,
/* Two input registers and one output register whose RID is the same as the second input */
DYNREG_INOUT,
/* 2 in, 1 out */
DYNREG_VEX3,
/* Two input registers with M dynamic output register */
DYNREG_TWOIN,
@ -44,7 +46,8 @@ typedef enum INTRINSFLAGS {
INTRINSFLAG_REGMODEMASK = 7,
INTRINSFLAG_MEMORYSIDE = 0x08, /* has memory side effects so needs an IR memory barrier */
/* Vex encoded opcode, vvvv may be unused though */
INTRINSFLAG_VEX = 0x10,
/* Intrinsic should be emitted as a naked function that is called */
INTRINSFLAG_CALLED = 0x20,
/* MODRM should always be set as indirect mode */
@ -72,6 +75,8 @@ typedef enum INTRINSFLAGS {
** user supplied code.
*/
INTRINSFLAG_TEMPLATE = 0x40000,
/* Opcode is only supported if the CPU supports AVX */
INTRINSFLAG_AVXREQ = 0x80000,
INTRINSFLAG_CALLEDIND = INTRINSFLAG_CALLED | INTRINSFLAG_INDIRECT
} INTRINSFLAGS;

View File

@ -236,6 +236,8 @@ typedef enum VEXMAP {
VEXMAP_0F3A = 3,
} VEXMAP;
#define VEX_256 0x40000
/* This list of x86 opcodes is not intended to be complete. Opcodes are only
** included when needed. Take a look at DynASM or jit.dis_x86 to see the
** whole mess.

View File

@ -6,6 +6,7 @@ typedef float float4 __attribute__((__vector_size__(16)));
typedef float float8 __attribute__((__vector_size__(32)));
typedef int int4 __attribute__((__vector_size__(16)));
typedef uint8_t byte16 __attribute__((__vector_size__(16)));
typedef int64_t long2 __attribute__((__vector_size__(16)));
]]
local float4 = ffi.new("float[4]")
@ -862,7 +863,7 @@ it("popcnt", function()
end)
it("addsd", function()
assert_cdef([[double addsd(double n1, double n2) __mcode("F20F58rM");]], "addsd")
assert_cdef([[double addsd(double n1, double n2) __mcode("F20F58rMv");]], "addsd")
local addsd = ffi.C.addsd
function test_addsd(n1, n2)
@ -890,7 +891,7 @@ it("addsd", function()
assert_equal(6, test_addsd2(3))
--check unfused
ffi.cdef([[double addsduf(double n1, double n2) __mcode("F20F58rR");]])
ffi.cdef([[double addsduf(double n1, double n2) __mcode("F20F58rRv");]])
addsd = ffi.C.addsduf
assert_equal(3, addsd(1, 2))
@ -898,7 +899,7 @@ it("addsd", function()
end)
it("addss", function()
assert_cdef([[float addss(float n1, float n2) __mcode("F30F58rM");]], "addss")
assert_cdef([[float addss(float n1, float n2) __mcode("F30F58rMv");]], "addss")
local addsd = ffi.C.addss
function test_addsd(n1, n2)
@ -922,7 +923,7 @@ it("addss", function()
assert_noexit(3, test_addss2, 1.5)
--check unfused
ffi.cdef[[float addssuf(float n1, float n2) __mcode("F30F58rR");]]
ffi.cdef[[float addssuf(float n1, float n2) __mcode("F30F58rRv");]]
addsd = ffi.C.addssuf
assert_equal(3, addsd(1, 2))
@ -957,6 +958,48 @@ it("shufps", function()
assert_equal(vout[3], 1.5)
end)
-- AVX vpermilps: permute the float lanes of v1 using per-lane indices in
-- the control vector (VEX-only opcode, declared with the 'V' flag).
it("vpermilps(avx)", function()
  assert_cdef([[float4 vpermilps(float4 v1, int4 control) __mcode("660F380CrMV");]], "vpermilps")
  local v1, v2 = ffi.new("float4", 1, 2, 3, 4)
  -- Broadcast lane 0 into every lane.
  v2 = ffi.new("int4", 0, 0, 0, 0)
  assert_v4eq(ffi.C.vpermilps(v1, v2), 1, 1, 1, 1)
  -- Reverse the vector
  v2 = ffi.new("int4", 3, 2, 1, 0)
  assert_v4eq(ffi.C.vpermilps(v1, v2), 4, 3, 2, 1)
end)
-- AVX vpslldq with an immediate byte shift of 4: the whole xmm register
-- shifts left one 32-bit slot, so each lane takes its lower neighbour's
-- value and lane 0 becomes zero.
it("vpslldq(avx)", function()
  assert_cdef([[int4 vpslldq_4(int4 v1) __mcode("660F737mUV", 4);]], "vpslldq_4")
  local input = ffi.new("int4", 1, 2, 3, 4)
  local shifted = ffi.C.vpslldq_4(input)
  assert_v4eq(shifted, 0, 1, 2, 3)
end)
-- 256-bit (ymm) vaddps: element-wise add of two float8 vectors, checking
-- that the VEX.L/256-bit form is selected for 8-float arguments.
it("vaddps(avx, ymm)", function()
  assert_cdef([[float8 vaddps(float8 v1, float8 v2) __mcode("0F58rMV");]], "vaddps")
  local lhs = ffi.new("float8", 1, 2, 3, 4, 5, 6, 7, 8)
  local rhs = ffi.new("float8", 1, 1, 1, 1, 1, 1, 1, 1)
  local sum = ffi.C.vaddps(lhs, rhs)
  -- Lane i held i+1, plus 1 from rhs.
  for lane = 0, 7 do
    assert_equal(sum[lane], lane + 2)
  end
end)
if ffi.arch == "x64" then
  -- Check the W bit is correctly set when the X opcode flag is specified on
  -- vex opcodes. NOTE(review): the test name says VEX.L, but the parser maps
  -- the X (REXW) flag to VEX.W for VEX-encoded ops — confirm which bit is meant.
  it("vmovd VEX.L bit(avx)", function()
    assert_cdef([[long2 vmovd(int64_t n) __mcode("660F6ErMVX");]], "vmovd")
    -- With W set, vmovd moves the full 64-bit value into lane 0.
    local v = ffi.C.vmovd(-1LL)
    assert_equal(v[0], -1LL)
  end)
end
it("phaddd 4byte opcode", function()
ffi.cdef([[int4 phaddd(int4 v1, int4 v2) __mcode("660F3802rM");]])