Compile string.find() and string.match()

This commit is contained in:
Karel Tuma 2015-10-01 22:46:53 +02:00
parent 22a9ed838b
commit cdf0b5dd73
7 changed files with 209 additions and 84 deletions

View File

@ -137,22 +137,6 @@ LJLIB_CF(string_dump)
/* macro to `unsign' a character */
#define uchar(c) ((unsigned char)(c))
#define CAP_UNFINISHED (-1)
#define CAP_POSITION (-2)
typedef struct MatchState {
const char *src_init; /* init of source string */
const char *src_end; /* end (`\0') of source string */
lua_State *L;
int level; /* total number of captures (finished or unfinished) */
int depth;
struct {
const char *init;
ptrdiff_t len;
} capture[LUA_MAXCAPTURES];
} MatchState;
#define L_ESC '%'
static int check_capture(MatchState *ms, int l)
@ -304,7 +288,7 @@ static const char *start_capture(MatchState *ms, const char *s,
const char *res;
int level = ms->level;
if (level >= LUA_MAXCAPTURES) lj_err_caller(ms->L, LJ_ERR_STRCAPN);
ms->capture[level].init = s;
setmref(ms->capture[level].init, s);
ms->capture[level].len = what;
ms->level = level+1;
if ((res=match(ms, s, p)) == NULL) /* match failed? */
@ -317,7 +301,7 @@ static const char *end_capture(MatchState *ms, const char *s,
{
int l = capture_to_close(ms);
const char *res;
ms->capture[l].len = s - ms->capture[l].init; /* close capture */
ms->capture[l].len = s - mref(ms->capture[l].init, char); /* close capture */
if ((res = match(ms, s, p)) == NULL) /* match failed? */
ms->capture[l].len = CAP_UNFINISHED; /* undo capture */
return res;
@ -329,7 +313,7 @@ static const char *match_capture(MatchState *ms, const char *s, int l)
l = check_capture(ms, l);
len = (size_t)ms->capture[l].len;
if ((size_t)(ms->src_end-s) >= len &&
memcmp(ms->capture[l].init, s, len) == 0)
memcmp(mref(ms->capture[l].init, char), s, len) == 0)
return s+len;
else
return NULL;
@ -420,6 +404,50 @@ static const char *match(MatchState *ms, const char *s, const char *p)
return s;
}
/*
** Run a pattern match on behalf of JIT-compiled code.
**
** s/slen describe the subject string, pstr the pattern and start the
** string.find()-style start index (1-based, negative counts from the end).
** The result is written to the single MatchState kept in the global state
** (G(L)->ms), so it is overwritten by the next call.
**
** Returns that MatchState on success (findret1/findret2 and the captures
** are filled in) or NULL if nothing matched.
**
** NOTE(review): a pattern without captures gets one synthesized whole-match
** capture (level becomes 1). A caller cannot tell this apart from a pattern
** with one explicit capture -- confirm the string.find() recorder does not
** emit it as an additional result.
*/
MatchState * lj_str_match(lua_State *L, const char *s, const char *pstr,
			  MSize slen, int32_t start)
{
  MatchState *ms = &G(L)->ms;
  const char *cur;
  MSize ofs;
  int anchored;

  /* Turn the start index into a 0-based offset, clamped at the front. */
  if (start < 0)
    start += (int32_t)slen;
  else
    start--;
  if (start < 0)
    start = 0;
  ofs = start;
  if (ofs > slen) {  /* Start lies beyond the end of the subject. */
#if LJ_52
    return NULL;
#else
    ofs = slen;
#endif
  }
  cur = s + ofs;

  /* A leading '^' pins the match to the start offset. */
  anchored = (*pstr == '^');
  if (anchored)
    pstr++;

  ms->L = L;
  ms->src_init = s;
  ms->src_end = s + slen;

  for (;;) {  /* Try every candidate position in turn. */
    const char *e;
    ms->depth = 0;
    ms->level = 0;
    e = match(ms, cur, pstr);
    if (e != NULL) {
      lua_assert(cur>=s);
      lua_assert(e>=s);
      ms->findret1 = (int32_t)(cur-s+1);
      ms->findret2 = (int32_t)(e-s);
      if (ms->level == 0) {  /* No captures: fake one for the whole match. */
	setmref(ms->capture[0].init, cur);
	ms->capture[0].len = e - cur;
	ms->level = 1;
      }
      return ms;
    }
    /* Advance unconditionally; stop past the end or if anchored. */
    if (!(cur++ < ms->src_end) || anchored)
      break;
  }
  return NULL;
}
static void push_onecapture(MatchState *ms, int i, const char *s, const char *e)
{
if (i >= ms->level) {
@ -431,9 +459,9 @@ static void push_onecapture(MatchState *ms, int i, const char *s, const char *e)
ptrdiff_t l = ms->capture[i].len;
if (l == CAP_UNFINISHED) lj_err_caller(ms->L, LJ_ERR_STRCAPU);
if (l == CAP_POSITION)
lua_pushinteger(ms->L, ms->capture[i].init - ms->src_init + 1);
lua_pushinteger(ms->L, mref(ms->capture[i].init,char) - ms->src_init + 1);
else
lua_pushlstring(ms->L, ms->capture[i].init, (size_t)l);
lua_pushlstring(ms->L, mref(ms->capture[i].init,char), (size_t)l);
}
}
@ -452,30 +480,32 @@ static int str_find_aux(lua_State *L, int find)
GCstr *s = lj_lib_checkstr(L, 1);
GCstr *p = lj_lib_checkstr(L, 2);
int32_t start = lj_lib_optint(L, 3, 1);
MSize st;
if (start < 0) start += (int32_t)s->len; else start--;
if (start < 0) start = 0;
st = (MSize)start;
if (st > s->len) {
#if LJ_52
setnilV(L->top-1);
return 1;
#else
st = s->len;
#endif
}
if (find && ((L->base+3 < L->top && tvistruecond(L->base+3)) ||
!lj_str_haspattern(p))) { /* Search for fixed string. */
const char *q = lj_str_find(strdata(s)+st, strdata(p), s->len-st, p->len);
if (q) {
setintV(L->top-2, (int32_t)(q-strdata(s)) + 1);
setintV(L->top-1, (int32_t)(q-strdata(s)) + (int32_t)p->len);
int n = lj_str_find(strdata(s), strdata(p), s->len, p->len,
start);
if (n) {
setintV(L->top-2, n);
setintV(L->top-1, n+p->len-1);
return 2;
}
} else { /* Search for pattern. */
MatchState ms;
const char *pstr = strdata(p);
const char *sstr = strdata(s) + st;
const char *sstr;
MSize st;
if (start < 0) start += (int32_t)s->len; else start--;
if (start < 0) start = 0;
st = (MSize)start;
if (st > s->len) {
#if LJ_52
setnilV(L->top-1);
return 1;
#else
st = s->len;
#endif
}
sstr = strdata(s) + st;
int anchor = 0;
if (*pstr == '^') { pstr++; anchor = 1; }
ms.L = L;
@ -500,12 +530,12 @@ static int str_find_aux(lua_State *L, int find)
return 1;
}
LJLIB_CF(string_find) LJLIB_REC(.)
LJLIB_CF(string_find) LJLIB_REC(string_findmatch 1)
{
return str_find_aux(L, 1);
}
LJLIB_CF(string_match)
LJLIB_CF(string_match) LJLIB_REC(string_findmatch 0)
{
return str_find_aux(L, 0);
}

View File

@ -899,17 +899,62 @@ static void LJ_FASTCALL recff_string_op(jit_State *J, RecordFFData *rd)
J->base[0] = emitir(IRT(IR_BUFSTR, IRT_STR), tr, hdr);
}
static void LJ_FASTCALL recff_string_find(jit_State *J, RecordFFData *rd)
static int recff_emit_captures(jit_State *J, const MatchState *ms,
TRef tr, int off, TRef sptr)
{
TRef trstr = lj_ir_tostr(J, J->base[0]);
TRef trpat = lj_ir_tostr(J, J->base[1]);
TRef trlen = emitir(IRTI(IR_FLOAD), trstr, IRFL_STR_LEN);
TRef tr0 = lj_ir_kint(J, 0);
TRef captures;
int i;
int capoff = offsetof(MatchState, capture[0]);
int capsize = sizeof(ms->capture[0]);
int initoff = offsetof(MatchState, capture[0].init) - capoff;
int lenoff = offsetof(MatchState, capture[0].len) - capoff;
if (!ms->level)
return 0;
captures = emitir(IRT(IR_ADD, IRT_P32), tr, /* IRFL_ not really applicable. */
lj_ir_kint(J, capoff));
for (i = 0; i < ms->level; i++) {
int init = i*capsize+initoff;
if (ms->capture[i].len == CAP_POSITION) {
J->base[off+i] =
emitir(IRT(IR_SUB, IRT_INT),
emitir(IRT(IR_XLOAD, IRT_P32),
emitir(IRT(IR_ADD, IRT_P32), captures, lj_ir_kint(J, init)), 0),
emitir(IRT(IR_SUB, IRT_P32), sptr, lj_ir_kint(J, 1))
); /* IR_SUB */
} else {
int len = i*capsize+lenoff;
J->base[off+i] =
emitir(IRT(IR_SNEW, IRT_STR),
emitir(IRT(IR_XLOAD, IRT_P32),
emitir(IRT(IR_ADD, IRT_P32), captures, lj_ir_kint(J, init)), 0),
emitir(IRT(IR_XLOAD, IRT_INT),
emitir(IRT(IR_ADD, IRT_P32), captures, lj_ir_kint(J, len)), 0)
); /* IR_SNEW */
}
}
return ms->level;
}
static void LJ_FASTCALL recff_string_findmatch(jit_State *J, RecordFFData *rd)
{
int find = rd->data;
TRef tr0, trstr, trsptr, trslen, trpat, trpptr;
TRef trstart;
int32_t start;
GCstr *str = argv2str(J, &rd->argv[0]);
GCstr *pat = argv2str(J, &rd->argv[1]);
int32_t start;
TRef kpat = 0;
int rawfind = 0;
J->needsnap = 1;
tr0 = lj_ir_kint(J, 0);
trstr = lj_ir_tostr(J, J->base[0]);
trpat = lj_ir_tostr(J, J->base[1]);
/* Optional start argument. */
if (tref_isnil(J->base[2])) {
trstart = lj_ir_kint(J, 1);
start = 1;
@ -917,44 +962,56 @@ static void LJ_FASTCALL recff_string_find(jit_State *J, RecordFFData *rd)
trstart = lj_opt_narrow_toint(J, J->base[2]);
start = argv2int(J, &rd->argv[2]);
}
trstart = recff_string_start(J, str, &start, trstart, trlen, tr0);
if ((MSize)start <= str->len) {
emitir(IRTGI(IR_ULE), trstart, trlen);
} else {
emitir(IRTGI(IR_UGT), trstart, trlen);
#if LJ_52
J->base[0] = TREF_NIL;
return;
#else
trstart = trlen;
start = str->len;
#endif
/* Specialize on pattern only if no raw flag is specified. */
if (!find || !(rawfind = (J->base[2] && tref_istruecond(J->base[3])))) {
kpat = lj_ir_kstr(J, pat);
emitir(IRTG(IR_EQ, IRT_STR), trpat, kpat);
trpat = kpat;
}
/* Fixed arg or no pattern matching chars? (Specialized to pattern string.) */
if ((J->base[2] && tref_istruecond(J->base[3])) ||
(emitir(IRTG(IR_EQ, IRT_STR), trpat, lj_ir_kstr(J, pat)),
!lj_str_haspattern(pat))) { /* Search for fixed string. */
TRef trsptr = emitir(IRT(IR_STRREF, IRT_P32), trstr, trstart);
TRef trpptr = emitir(IRT(IR_STRREF, IRT_P32), trpat, tr0);
TRef trslen = emitir(IRTI(IR_SUB), trlen, trstart);
trsptr = emitir(IRT(IR_STRREF, IRT_P32), trstr, tr0);
trslen = emitir(IRTI(IR_FLOAD), trstr, IRFL_STR_LEN);
trpptr = emitir(IRT(IR_STRREF, IRT_P32), trpat, tr0);
if (rawfind || !lj_str_haspattern(pat)) {
TRef trplen = emitir(IRTI(IR_FLOAD), trpat, IRFL_STR_LEN);
TRef tr = lj_ir_call(J, IRCALL_lj_str_find, trsptr, trpptr, trslen, trplen);
TRef tr = lj_ir_call(J, IRCALL_lj_str_find,
trsptr, trpptr, trslen, trplen, trstart);
if (lj_str_find(strdata(str), strdata(pat),
str->len, pat->len, start)) {
emitir(IRTG(IR_NE, IRT_INT), tr, tr0);
if (!find)
J->base[0] = kpat;
else {
J->base[0] = tr;
J->base[1] = emitir(IRTI(IR_ADD), tr, emitir(IRTI(IR_SUB), trplen,
lj_ir_kint(J, 1)));
rd->nres = 2;
}
} else {
emitir(IRTG(IR_EQ, IRT_INT), tr, tr0);
J->base[0] = TREF_NIL;
}
} else {
TRef tr = lj_ir_call(J, IRCALL_lj_str_match,
trsptr, trpptr, trslen, trstart);
TRef trp0 = lj_ir_kkptr(J, NULL);
if (lj_str_find(strdata(str)+(MSize)start, strdata(pat),
str->len-(MSize)start, pat->len)) {
TRef pos;
MatchState *ms = lj_str_match(J->L, strdata(str), strdata(pat),
str->len, start);
if (ms) {
int rpos = 0;
emitir(IRTG(IR_NE, IRT_P32), tr, trp0);
pos = emitir(IRTI(IR_SUB), tr, emitir(IRT(IR_STRREF, IRT_P32), trstr, tr0));
J->base[0] = emitir(IRTI(IR_ADD), pos, lj_ir_kint(J, 1));
J->base[1] = emitir(IRTI(IR_ADD), pos, trplen);
rd->nres = 2;
if (find) {
J->base[0] = emitir(IRTI(IR_FLOAD), tr, IRFL_MS_FINDRET1);
J->base[1] = emitir(IRTI(IR_FLOAD), tr, IRFL_MS_FINDRET2);
rpos = 2;
}
rd->nres = rpos + recff_emit_captures(J, ms, tr, rpos, trsptr);
} else {
emitir(IRTG(IR_EQ, IRT_P32), tr, trp0);
J->base[0] = TREF_NIL;
}
} else { /* Search for pattern. */
recff_nyiu(J, rd);
return;
}
}

View File

@ -209,7 +209,9 @@ IRFPMDEF(FPMENUM)
_(CDATA_PTR, sizeof(GCcdata)) \
_(CDATA_INT, sizeof(GCcdata)) \
_(CDATA_INT64, sizeof(GCcdata)) \
_(CDATA_INT64_4, sizeof(GCcdata) + 4)
_(CDATA_INT64_4, sizeof(GCcdata) + 4) \
_(MS_FINDRET1,offsetof(MatchState, findret1)) \
_(MS_FINDRET2,offsetof(MatchState, findret2)) \
typedef enum {
#define FLENUM(name, ofs) IRFL_##name,

View File

@ -123,7 +123,8 @@ typedef struct CCallInfo {
/* Function definitions for CALL* instructions. */
#define IRCALLDEF(_) \
_(ANY, lj_str_cmp, 2, FN, INT, CCI_NOFPRCLOBBER) \
_(ANY, lj_str_find, 4, N, P32, 0) \
_(ANY, lj_str_match, 5, N, P32, CCI_L) \
_(ANY, lj_str_find, 5, N, INT, 0) \
_(ANY, lj_str_new, 3, S, STR, CCI_L) \
_(ANY, lj_strscan_num, 2, FN, INT, 0) \
_(ANY, lj_strfmt_int, 2, FN, STR, CCI_L) \

View File

@ -590,6 +590,23 @@ typedef struct GCState {
MSize pause; /* Pause between successive GC cycles. */
} GCState;
/* Match state for pattern captures. */
/*
** Besides the matcher's own working state this also carries the results
** that lj_str_match() leaves behind: the find indices and the capture list.
** Field offsets are consumed via offsetof() elsewhere, so the layout matters.
*/
typedef struct MatchState {
const char *src_init; /* Start of source string. */
const char *src_end; /* End (`\0') of source string. */
lua_State *L; /* State passed to lj_err_caller() when a capture error is raised. */
int level; /* Total number of captures (finished or unfinished). */
int depth; /* Reset to 0 before each match attempt; presumably the matcher's nesting depth -- TODO confirm. */
uint32_t findret1; /* 1-based index of the first matched character. */
uint32_t findret2; /* Return indices of string.find(). */
struct {
MRef init; /* Start of the capture inside the source string. */
MSize len; /* Capture length, or one of the CAP_* markers below. */
} capture[LUA_MAXCAPTURES];
} MatchState;
/* Marker stored in capture[].len: capture was opened but is not closed. */
#define CAP_UNFINISHED ((MSize)(-1))
/* Marker stored in capture[].len: position capture, yields an index instead of a substring. */
#define CAP_POSITION ((MSize)(-2))
/* Global state, shared by all threads of a Lua universe. */
typedef struct global_State {
GCRef *strhash; /* String hash table (hash chain anchors). */
@ -621,6 +638,7 @@ typedef struct global_State {
MRef jit_base; /* Current JIT code L->base or NULL. */
MRef ctype_state; /* Pointer to C type state. */
GCRef gcroot[GCROOT_MAX]; /* GC roots. */
MatchState ms; /* Capture buffer for JIT mcode. */
} global_State;
#define mainthread(g) (&gcref(g->mainthref)->th)

View File

@ -59,23 +59,38 @@ static LJ_AINLINE int str_fastcmp(const char *a, const char *b, MSize len)
}
/* Find fixed string p inside string s. */
const char *lj_str_find(const char *s, const char *p, MSize slen, MSize plen)
uint32_t lj_str_find(const char *s, const char *p, MSize slen, MSize plen,
int32_t start)
{
const char *os = s;
if (start < 0) start += (int32_t)slen; else start--;
if (start < 0) start = 0;
if (start > slen)
#if LJ_52
return 0;
#else
start = slen;
#endif
s += start;
slen -= start;
if (plen <= slen) {
if (plen == 0) {
return s;
return start+1;
} else {
int c = *(const uint8_t *)p++;
plen--; slen -= plen;
while (slen) {
const char *q = (const char *)memchr(s, c, slen);
if (!q) break;
if (memcmp(q+1, p, plen) == 0) return q;
if (memcmp(q+1, p, plen) == 0) return (q-os+1);
q++; slen -= (MSize)(q-s); s = q;
}
}
}
return NULL;
return 0;
}
/* Check whether a string has a pattern matching character. */

View File

@ -12,9 +12,11 @@
/* String helpers. */
LJ_FUNC int32_t LJ_FASTCALL lj_str_cmp(GCstr *a, GCstr *b);
LJ_FUNC const char *lj_str_find(const char *s, const char *f,
MSize slen, MSize flen);
LJ_FUNC uint32_t lj_str_find(const char *s, const char *p, MSize slen,
MSize plen, int32_t start);
LJ_FUNC int lj_str_haspattern(GCstr *s);
LJ_FUNC MatchState * lj_str_match(lua_State *L, const char *s, const char *p,
MSize slen, int32_t start);
/* String interning. */
LJ_FUNC void lj_str_resize(lua_State *L, MSize newmask);