Compile string.find() and string.match()

This commit is contained in:
Karel Tuma 2015-10-01 22:46:53 +02:00
parent 22a9ed838b
commit cdf0b5dd73
7 changed files with 209 additions and 84 deletions

View File

@ -137,22 +137,6 @@ LJLIB_CF(string_dump)
/* macro to `unsign' a character */ /* macro to `unsign' a character */
#define uchar(c) ((unsigned char)(c)) #define uchar(c) ((unsigned char)(c))
#define CAP_UNFINISHED (-1)
#define CAP_POSITION (-2)
typedef struct MatchState {
const char *src_init; /* init of source string */
const char *src_end; /* end (`\0') of source string */
lua_State *L;
int level; /* total number of captures (finished or unfinished) */
int depth;
struct {
const char *init;
ptrdiff_t len;
} capture[LUA_MAXCAPTURES];
} MatchState;
#define L_ESC '%' #define L_ESC '%'
static int check_capture(MatchState *ms, int l) static int check_capture(MatchState *ms, int l)
@ -304,7 +288,7 @@ static const char *start_capture(MatchState *ms, const char *s,
const char *res; const char *res;
int level = ms->level; int level = ms->level;
if (level >= LUA_MAXCAPTURES) lj_err_caller(ms->L, LJ_ERR_STRCAPN); if (level >= LUA_MAXCAPTURES) lj_err_caller(ms->L, LJ_ERR_STRCAPN);
ms->capture[level].init = s; setmref(ms->capture[level].init, s);
ms->capture[level].len = what; ms->capture[level].len = what;
ms->level = level+1; ms->level = level+1;
if ((res=match(ms, s, p)) == NULL) /* match failed? */ if ((res=match(ms, s, p)) == NULL) /* match failed? */
@ -317,7 +301,7 @@ static const char *end_capture(MatchState *ms, const char *s,
{ {
int l = capture_to_close(ms); int l = capture_to_close(ms);
const char *res; const char *res;
ms->capture[l].len = s - ms->capture[l].init; /* close capture */ ms->capture[l].len = s - mref(ms->capture[l].init, char); /* close capture */
if ((res = match(ms, s, p)) == NULL) /* match failed? */ if ((res = match(ms, s, p)) == NULL) /* match failed? */
ms->capture[l].len = CAP_UNFINISHED; /* undo capture */ ms->capture[l].len = CAP_UNFINISHED; /* undo capture */
return res; return res;
@ -329,7 +313,7 @@ static const char *match_capture(MatchState *ms, const char *s, int l)
l = check_capture(ms, l); l = check_capture(ms, l);
len = (size_t)ms->capture[l].len; len = (size_t)ms->capture[l].len;
if ((size_t)(ms->src_end-s) >= len && if ((size_t)(ms->src_end-s) >= len &&
memcmp(ms->capture[l].init, s, len) == 0) memcmp(mref(ms->capture[l].init, char), s, len) == 0)
return s+len; return s+len;
else else
return NULL; return NULL;
@ -420,6 +404,50 @@ static const char *match(MatchState *ms, const char *s, const char *p)
return s; return s;
} }
/*
** Match a Lua pattern and store the result for JIT code.
**
** L      Lua state; its global state supplies the MatchState that is
**        filled in (G(L)->ms) and is recorded in the state for match().
** s      Start of the subject string.
** pstr   Pattern string. A leading '^' anchors the match to the start
**        position and is skipped before matching.
** slen   Length of the subject string.
** start  1-based start position; a negative value is offset by slen.
**        The resulting 0-based offset is clamped to 0 from below. An offset
**        past slen yields NULL when LJ_52 is set, else it is clamped to slen.
**
** Returns the filled-in MatchState on success, NULL if nothing matched.
** On success findret1 holds the 1-based index where the match began and
** findret2 the offset of the pointer returned by match() from s.
*/
MatchState * lj_str_match(lua_State *L, const char *s, const char *pstr,
			  MSize slen, int32_t start)
{
  MatchState *state = &G(L)->ms;
  const char *cur;
  MSize ofs;
  int anchored;

  /* Convert the 1-based/negative start argument into a 0-based offset. */
  if (start >= 0)
    start--;
  else
    start += (int32_t)slen;
  if (start < 0)
    start = 0;
  ofs = start;
  if (ofs > slen) {
#if LJ_52
    return NULL;
#else
    ofs = slen;
#endif
  }

  /* An anchored pattern is tried at the start offset only. */
  anchored = (*pstr == '^');
  if (anchored)
    pstr++;

  state->L = L;
  state->src_init = s;
  state->src_end = s + slen;

  /* Try each candidate position in turn, including the end of the string. */
  for (cur = s + ofs; ; cur++) {
    const char *stop;
    state->depth = 0;
    state->level = 0;
    stop = match(state, cur, pstr);
    if (stop != NULL) {
      lua_assert(cur>=s);
      lua_assert(stop>=s);
      state->findret1 = (int32_t)(cur-s+1);
      state->findret2 = (int32_t)(stop-s);
      /* Pattern had no captures: fake one spanning cur..stop. */
      if (state->level == 0) {
	setmref(state->capture[0].init, cur);
	state->capture[0].len = stop - cur;
	state->level = 1;
      }
      return state;
    }
    if (anchored || cur >= state->src_end)
      break;
  }
  return NULL;
}
static void push_onecapture(MatchState *ms, int i, const char *s, const char *e) static void push_onecapture(MatchState *ms, int i, const char *s, const char *e)
{ {
if (i >= ms->level) { if (i >= ms->level) {
@ -431,9 +459,9 @@ static void push_onecapture(MatchState *ms, int i, const char *s, const char *e)
ptrdiff_t l = ms->capture[i].len; ptrdiff_t l = ms->capture[i].len;
if (l == CAP_UNFINISHED) lj_err_caller(ms->L, LJ_ERR_STRCAPU); if (l == CAP_UNFINISHED) lj_err_caller(ms->L, LJ_ERR_STRCAPU);
if (l == CAP_POSITION) if (l == CAP_POSITION)
lua_pushinteger(ms->L, ms->capture[i].init - ms->src_init + 1); lua_pushinteger(ms->L, mref(ms->capture[i].init,char) - ms->src_init + 1);
else else
lua_pushlstring(ms->L, ms->capture[i].init, (size_t)l); lua_pushlstring(ms->L, mref(ms->capture[i].init,char), (size_t)l);
} }
} }
@ -452,6 +480,19 @@ static int str_find_aux(lua_State *L, int find)
GCstr *s = lj_lib_checkstr(L, 1); GCstr *s = lj_lib_checkstr(L, 1);
GCstr *p = lj_lib_checkstr(L, 2); GCstr *p = lj_lib_checkstr(L, 2);
int32_t start = lj_lib_optint(L, 3, 1); int32_t start = lj_lib_optint(L, 3, 1);
if (find && ((L->base+3 < L->top && tvistruecond(L->base+3)) ||
!lj_str_haspattern(p))) { /* Search for fixed string. */
int n = lj_str_find(strdata(s), strdata(p), s->len, p->len,
start);
if (n) {
setintV(L->top-2, n);
setintV(L->top-1, n+p->len-1);
return 2;
}
} else { /* Search for pattern. */
MatchState ms;
const char *pstr = strdata(p);
const char *sstr;
MSize st; MSize st;
if (start < 0) start += (int32_t)s->len; else start--; if (start < 0) start += (int32_t)s->len; else start--;
if (start < 0) start = 0; if (start < 0) start = 0;
@ -464,18 +505,7 @@ static int str_find_aux(lua_State *L, int find)
st = s->len; st = s->len;
#endif #endif
} }
if (find && ((L->base+3 < L->top && tvistruecond(L->base+3)) || sstr = strdata(s) + st;
!lj_str_haspattern(p))) { /* Search for fixed string. */
const char *q = lj_str_find(strdata(s)+st, strdata(p), s->len-st, p->len);
if (q) {
setintV(L->top-2, (int32_t)(q-strdata(s)) + 1);
setintV(L->top-1, (int32_t)(q-strdata(s)) + (int32_t)p->len);
return 2;
}
} else { /* Search for pattern. */
MatchState ms;
const char *pstr = strdata(p);
const char *sstr = strdata(s) + st;
int anchor = 0; int anchor = 0;
if (*pstr == '^') { pstr++; anchor = 1; } if (*pstr == '^') { pstr++; anchor = 1; }
ms.L = L; ms.L = L;
@ -500,12 +530,12 @@ static int str_find_aux(lua_State *L, int find)
return 1; return 1;
} }
/*
** Library function string_find: forwards to str_find_aux() with find = 1.
** NOTE(review): each line of this diff view carries the old revision followed
** by the new one. On the declaration line the new revision replaces
** LJLIB_REC(.) with LJLIB_REC(string_findmatch 1); the trailing 1 is
** presumably the value the recorder reads as rd->data -- confirm.
*/
LJLIB_CF(string_find) LJLIB_REC(.) LJLIB_CF(string_find) LJLIB_REC(string_findmatch 1)
{ {
return str_find_aux(L, 1); return str_find_aux(L, 1);
} }
/*
** Library function string_match: forwards to str_find_aux() with find = 0.
** NOTE(review): each line of this diff view carries the old revision followed
** by the new one. The new revision adds LJLIB_REC(string_findmatch 0) to the
** declaration; the trailing 0 is presumably the value the recorder reads as
** rd->data -- confirm.
*/
LJLIB_CF(string_match) LJLIB_CF(string_match) LJLIB_REC(string_findmatch 0)
{ {
return str_find_aux(L, 0); return str_find_aux(L, 0);
} }

View File

@ -899,17 +899,62 @@ static void LJ_FASTCALL recff_string_op(jit_State *J, RecordFFData *rd)
J->base[0] = emitir(IRT(IR_BUFSTR, IRT_STR), tr, hdr); J->base[0] = emitir(IRT(IR_BUFSTR, IRT_STR), tr, hdr);
} }
/*
** recff_emit_captures (new revision, right-hand side of this diff view):
** emit IR that turns the captures stored in a MatchState into result slots.
**
** J     JIT state; results are written to J->base[off+i].
** ms    MatchState observed at recording time; supplies the capture count
**       (ms->level) and decides per capture whether a position or a string
**       is emitted.
** tr    IR reference used as the base address of the MatchState at run time.
** off   Index of the first result slot to fill.
** sptr  IR reference to the start of the subject string.
**
** Returns the number of slots written (ms->level), or 0 if there are none.
**
** NOTE(review): the first lines below also carry the removed prologue of the
** old recff_string_find() on their left-hand side.
** NOTE(review): the position-vs-string choice is taken from the recorded
** ms->capture[i].len and no guard for it is emitted in this function --
** confirm that specialization on the pattern elsewhere makes this safe.
*/
static void LJ_FASTCALL recff_string_find(jit_State *J, RecordFFData *rd) static int recff_emit_captures(jit_State *J, const MatchState *ms,
TRef tr, int off, TRef sptr)
{ {
TRef trstr = lj_ir_tostr(J, J->base[0]); TRef captures;
TRef trpat = lj_ir_tostr(J, J->base[1]); int i;
/* capoff: byte offset of capture[0] inside MatchState. */
TRef trlen = emitir(IRTI(IR_FLOAD), trstr, IRFL_STR_LEN); int capoff = offsetof(MatchState, capture[0]);
/* capsize: size in bytes of one capture entry. */
TRef tr0 = lj_ir_kint(J, 0); int capsize = sizeof(ms->capture[0]);
/* Offsets of the init and len fields relative to the start of an entry. */
int initoff = offsetof(MatchState, capture[0].init) - capoff;
int lenoff = offsetof(MatchState, capture[0].len) - capoff;
if (!ms->level)
return 0;
/* captures = address of the capture array (tr + capoff). */
captures = emitir(IRT(IR_ADD, IRT_P32), tr, /* IRFL_ not really applicable. */
lj_ir_kint(J, capoff));
for (i = 0; i < ms->level; i++) {
/* Byte offset of capture[i].init from the start of the capture array. */
int init = i*capsize+initoff;
if (ms->capture[i].len == CAP_POSITION) {
/* Position capture: loaded init pointer minus (sptr - 1), a 1-based index. */
J->base[off+i] =
emitir(IRT(IR_SUB, IRT_INT),
emitir(IRT(IR_XLOAD, IRT_P32),
emitir(IRT(IR_ADD, IRT_P32), captures, lj_ir_kint(J, init)), 0),
emitir(IRT(IR_SUB, IRT_P32), sptr, lj_ir_kint(J, 1))
); /* IR_SUB */
} else {
/* String capture: create a string from the loaded init pointer and len. */
int len = i*capsize+lenoff;
J->base[off+i] =
emitir(IRT(IR_SNEW, IRT_STR),
emitir(IRT(IR_XLOAD, IRT_P32),
emitir(IRT(IR_ADD, IRT_P32), captures, lj_ir_kint(J, init)), 0),
emitir(IRT(IR_XLOAD, IRT_INT),
emitir(IRT(IR_ADD, IRT_P32), captures, lj_ir_kint(J, len)), 0)
); /* IR_SNEW */
}
}
return ms->level;
}
static void LJ_FASTCALL recff_string_findmatch(jit_State *J, RecordFFData *rd)
{
int find = rd->data;
TRef tr0, trstr, trsptr, trslen, trpat, trpptr;
TRef trstart; TRef trstart;
int32_t start;
GCstr *str = argv2str(J, &rd->argv[0]); GCstr *str = argv2str(J, &rd->argv[0]);
GCstr *pat = argv2str(J, &rd->argv[1]); GCstr *pat = argv2str(J, &rd->argv[1]);
int32_t start; TRef kpat = 0;
int rawfind = 0;
J->needsnap = 1; J->needsnap = 1;
tr0 = lj_ir_kint(J, 0);
trstr = lj_ir_tostr(J, J->base[0]);
trpat = lj_ir_tostr(J, J->base[1]);
/* Optional start argument. */
if (tref_isnil(J->base[2])) { if (tref_isnil(J->base[2])) {
trstart = lj_ir_kint(J, 1); trstart = lj_ir_kint(J, 1);
start = 1; start = 1;
@ -917,44 +962,56 @@ static void LJ_FASTCALL recff_string_find(jit_State *J, RecordFFData *rd)
trstart = lj_opt_narrow_toint(J, J->base[2]); trstart = lj_opt_narrow_toint(J, J->base[2]);
start = argv2int(J, &rd->argv[2]); start = argv2int(J, &rd->argv[2]);
} }
trstart = recff_string_start(J, str, &start, trstart, trlen, tr0);
if ((MSize)start <= str->len) { /* Specialize on pattern only if no raw flag is specified. */
emitir(IRTGI(IR_ULE), trstart, trlen); if (!find || !(rawfind = (J->base[2] && tref_istruecond(J->base[3])))) {
} else { kpat = lj_ir_kstr(J, pat);
emitir(IRTGI(IR_UGT), trstart, trlen); emitir(IRTG(IR_EQ, IRT_STR), trpat, kpat);
#if LJ_52 trpat = kpat;
J->base[0] = TREF_NIL;
return;
#else
trstart = trlen;
start = str->len;
#endif
} }
/* Fixed arg or no pattern matching chars? (Specialized to pattern string.) */
if ((J->base[2] && tref_istruecond(J->base[3])) || trsptr = emitir(IRT(IR_STRREF, IRT_P32), trstr, tr0);
(emitir(IRTG(IR_EQ, IRT_STR), trpat, lj_ir_kstr(J, pat)), trslen = emitir(IRTI(IR_FLOAD), trstr, IRFL_STR_LEN);
!lj_str_haspattern(pat))) { /* Search for fixed string. */ trpptr = emitir(IRT(IR_STRREF, IRT_P32), trpat, tr0);
TRef trsptr = emitir(IRT(IR_STRREF, IRT_P32), trstr, trstart);
TRef trpptr = emitir(IRT(IR_STRREF, IRT_P32), trpat, tr0); if (rawfind || !lj_str_haspattern(pat)) {
TRef trslen = emitir(IRTI(IR_SUB), trlen, trstart);
TRef trplen = emitir(IRTI(IR_FLOAD), trpat, IRFL_STR_LEN); TRef trplen = emitir(IRTI(IR_FLOAD), trpat, IRFL_STR_LEN);
TRef tr = lj_ir_call(J, IRCALL_lj_str_find, trsptr, trpptr, trslen, trplen); TRef tr = lj_ir_call(J, IRCALL_lj_str_find,
TRef trp0 = lj_ir_kkptr(J, NULL); trsptr, trpptr, trslen, trplen, trstart);
if (lj_str_find(strdata(str)+(MSize)start, strdata(pat), if (lj_str_find(strdata(str), strdata(pat),
str->len-(MSize)start, pat->len)) { str->len, pat->len, start)) {
TRef pos; emitir(IRTG(IR_NE, IRT_INT), tr, tr0);
emitir(IRTG(IR_NE, IRT_P32), tr, trp0); if (!find)
pos = emitir(IRTI(IR_SUB), tr, emitir(IRT(IR_STRREF, IRT_P32), trstr, tr0)); J->base[0] = kpat;
J->base[0] = emitir(IRTI(IR_ADD), pos, lj_ir_kint(J, 1)); else {
J->base[1] = emitir(IRTI(IR_ADD), pos, trplen); J->base[0] = tr;
J->base[1] = emitir(IRTI(IR_ADD), tr, emitir(IRTI(IR_SUB), trplen,
lj_ir_kint(J, 1)));
rd->nres = 2; rd->nres = 2;
}
} else {
emitir(IRTG(IR_EQ, IRT_INT), tr, tr0);
J->base[0] = TREF_NIL;
}
} else {
TRef tr = lj_ir_call(J, IRCALL_lj_str_match,
trsptr, trpptr, trslen, trstart);
TRef trp0 = lj_ir_kkptr(J, NULL);
MatchState *ms = lj_str_match(J->L, strdata(str), strdata(pat),
str->len, start);
if (ms) {
int rpos = 0;
emitir(IRTG(IR_NE, IRT_P32), tr, trp0);
if (find) {
J->base[0] = emitir(IRTI(IR_FLOAD), tr, IRFL_MS_FINDRET1);
J->base[1] = emitir(IRTI(IR_FLOAD), tr, IRFL_MS_FINDRET2);
rpos = 2;
}
rd->nres = rpos + recff_emit_captures(J, ms, tr, rpos, trsptr);
} else { } else {
emitir(IRTG(IR_EQ, IRT_P32), tr, trp0); emitir(IRTG(IR_EQ, IRT_P32), tr, trp0);
J->base[0] = TREF_NIL; J->base[0] = TREF_NIL;
} }
} else { /* Search for pattern. */
recff_nyiu(J, rd);
return;
} }
} }

View File

@ -209,7 +209,9 @@ IRFPMDEF(FPMENUM)
_(CDATA_PTR, sizeof(GCcdata)) \ _(CDATA_PTR, sizeof(GCcdata)) \
_(CDATA_INT, sizeof(GCcdata)) \ _(CDATA_INT, sizeof(GCcdata)) \
_(CDATA_INT64, sizeof(GCcdata)) \ _(CDATA_INT64, sizeof(GCcdata)) \
_(CDATA_INT64_4, sizeof(GCcdata) + 4) _(CDATA_INT64_4, sizeof(GCcdata) + 4) \
_(MS_FINDRET1,offsetof(MatchState, findret1)) \
_(MS_FINDRET2,offsetof(MatchState, findret2)) \
typedef enum { typedef enum {
#define FLENUM(name, ofs) IRFL_##name, #define FLENUM(name, ofs) IRFL_##name,

View File

@ -123,7 +123,8 @@ typedef struct CCallInfo {
/* Function definitions for CALL* instructions. */ /* Function definitions for CALL* instructions. */
#define IRCALLDEF(_) \ #define IRCALLDEF(_) \
_(ANY, lj_str_cmp, 2, FN, INT, CCI_NOFPRCLOBBER) \ _(ANY, lj_str_cmp, 2, FN, INT, CCI_NOFPRCLOBBER) \
_(ANY, lj_str_find, 4, N, P32, 0) \ _(ANY, lj_str_match, 5, N, P32, CCI_L) \
_(ANY, lj_str_find, 5, N, INT, 0) \
_(ANY, lj_str_new, 3, S, STR, CCI_L) \ _(ANY, lj_str_new, 3, S, STR, CCI_L) \
_(ANY, lj_strscan_num, 2, FN, INT, 0) \ _(ANY, lj_strscan_num, 2, FN, INT, 0) \
_(ANY, lj_strfmt_int, 2, FN, STR, CCI_L) \ _(ANY, lj_strfmt_int, 2, FN, STR, CCI_L) \

View File

@ -590,6 +590,23 @@ typedef struct GCState {
MSize pause; /* Pause between successive GC cycles. */ MSize pause; /* Pause between successive GC cycles. */
} GCState; } GCState;
/*
** Match state for pattern captures.
** The field layout is read through offsetof() by the trace recorder and by
** the IRFL_MS_FINDRET1/IRFL_MS_FINDRET2 field definitions.
*/
typedef struct MatchState {
const char *src_init; /* Start of source string. */
const char *src_end; /* End (`\0') of source string. */
lua_State *L; /* Lua state passed to lj_err_caller() on pattern errors. */
int level; /* Total number of captures (finished or unfinished). */
int depth; /* Reset to 0 before each match attempt; presumably the match() recursion depth -- confirm. */
uint32_t findret1; /* Set by lj_str_match(): 1-based index where the match began. */
uint32_t findret2; /* Return indices of string.find(). */
struct {
MRef init; /* Start of the capture; read back with mref(init, char). */
MSize len; /* Capture length, or CAP_UNFINISHED / CAP_POSITION. */
} capture[LUA_MAXCAPTURES];
} MatchState;
/* len sentinel restored by end_capture() when a close is undone; rejected by push_onecapture(). */
#define CAP_UNFINISHED ((MSize)(-1))
/* len sentinel for which push_onecapture() pushes a 1-based position instead of a substring. */
#define CAP_POSITION ((MSize)(-2))
/* Global state, shared by all threads of a Lua universe. */ /* Global state, shared by all threads of a Lua universe. */
typedef struct global_State { typedef struct global_State {
GCRef *strhash; /* String hash table (hash chain anchors). */ GCRef *strhash; /* String hash table (hash chain anchors). */
@ -621,6 +638,7 @@ typedef struct global_State {
MRef jit_base; /* Current JIT code L->base or NULL. */ MRef jit_base; /* Current JIT code L->base or NULL. */
MRef ctype_state; /* Pointer to C type state. */ MRef ctype_state; /* Pointer to C type state. */
GCRef gcroot[GCROOT_MAX]; /* GC roots. */ GCRef gcroot[GCROOT_MAX]; /* GC roots. */
MatchState ms; /* Capture buffer for JIT mcode. */
} global_State; } global_State;
#define mainthread(g) (&gcref(g->mainthref)->th) #define mainthread(g) (&gcref(g->mainthref)->th)

View File

@ -59,23 +59,38 @@ static LJ_AINLINE int str_fastcmp(const char *a, const char *b, MSize len)
} }
/*
** lj_str_find: search for the fixed string p (length plen) inside the string
** s (length slen).
**
** NOTE(review): every line of this diff view carries the old revision followed
** by the new one. The old revision returned a pointer to the first occurrence
** or NULL. The new revision takes an extra start argument and returns the
** 1-based index of the first occurrence at or after start, or 0 if there is
** none.
*/
/* Find fixed string p inside string s. */ /* Find fixed string p inside string s. */
const char *lj_str_find(const char *s, const char *p, MSize slen, MSize plen) uint32_t lj_str_find(const char *s, const char *p, MSize slen, MSize plen,
int32_t start)
{ {
/* Remember the original start of s so a 1-based index can be returned. */
const char *os = s;
/* Negative start is offset by slen; otherwise convert 1-based to 0-based. */
if (start < 0) start += (int32_t)slen; else start--;
if (start < 0) start = 0;
/* Start past the end: not found under LJ_52, otherwise clamp to slen. */
if (start > slen)
#if LJ_52
return 0;
#else
start = slen;
#endif
/* Restrict the search to the part of s from start onwards. */
s += start;
slen -= start;
if (plen <= slen) { if (plen <= slen) {
/* An empty p matches immediately at the current position. */
if (plen == 0) { if (plen == 0) {
return s; return start+1;
} else { } else {
/* c = first byte of p; p and plen now describe the remaining bytes. */
int c = *(const uint8_t *)p++; int c = *(const uint8_t *)p++;
/* After this, slen counts the candidate positions for the first byte. */
plen--; slen -= plen; plen--; slen -= plen;
while (slen) { while (slen) {
/* Locate the next occurrence of the first byte, then compare the rest. */
const char *q = (const char *)memchr(s, c, slen); const char *q = (const char *)memchr(s, c, slen);
if (!q) break; if (!q) break;
if (memcmp(q+1, p, plen) == 0) return q; if (memcmp(q+1, p, plen) == 0) return (q-os+1);
/* No match here: continue just past this occurrence. */
q++; slen -= (MSize)(q-s); s = q; q++; slen -= (MSize)(q-s); s = q;
} }
} }
} }
return NULL; return 0;
} }
/* Check whether a string has a pattern matching character. */ /* Check whether a string has a pattern matching character. */

View File

@ -12,9 +12,11 @@
/* String helpers. */ /* String helpers. */
LJ_FUNC int32_t LJ_FASTCALL lj_str_cmp(GCstr *a, GCstr *b); LJ_FUNC int32_t LJ_FASTCALL lj_str_cmp(GCstr *a, GCstr *b);
LJ_FUNC const char *lj_str_find(const char *s, const char *f, LJ_FUNC uint32_t lj_str_find(const char *s, const char *p, MSize slen,
MSize slen, MSize flen); MSize plen, int32_t start);
LJ_FUNC int lj_str_haspattern(GCstr *s); LJ_FUNC int lj_str_haspattern(GCstr *s);
LJ_FUNC MatchState * lj_str_match(lua_State *L, const char *s, const char *p,
MSize slen, int32_t start);
/* String interning. */ /* String interning. */
LJ_FUNC void lj_str_resize(lua_State *L, MSize newmask); LJ_FUNC void lj_str_resize(lua_State *L, MSize newmask);