Limit recursion depth in string.match() et al.

This commit is contained in:
Mike Pall 2012-08-28 21:22:23 +02:00
parent 751cd9d821
commit ff00a78f3a
2 changed files with 41 additions and 26 deletions

View File

@ -148,6 +148,7 @@ typedef struct MatchState {
const char *src_end; /* end (`\0') of source string */ const char *src_end; /* end (`\0') of source string */
lua_State *L; lua_State *L;
int level; /* total number of captures (finished or unfinished) */ int level; /* total number of captures (finished or unfinished) */
int depth;
struct { struct {
const char *init; const char *init;
ptrdiff_t len; ptrdiff_t len;
@ -339,22 +340,26 @@ static const char *match_capture(MatchState *ms, const char *s, int l)
static const char *match(MatchState *ms, const char *s, const char *p) static const char *match(MatchState *ms, const char *s, const char *p)
{ {
if (++ms->depth > LJ_MAX_XLEVEL)
lj_err_caller(ms->L, LJ_ERR_STRPATX);
init: /* using goto's to optimize tail recursion */ init: /* using goto's to optimize tail recursion */
switch (*p) { switch (*p) {
case '(': /* start capture */ case '(': /* start capture */
if (*(p+1) == ')') /* position capture? */ if (*(p+1) == ')') /* position capture? */
return start_capture(ms, s, p+2, CAP_POSITION); s = start_capture(ms, s, p+2, CAP_POSITION);
else else
return start_capture(ms, s, p+1, CAP_UNFINISHED); s = start_capture(ms, s, p+1, CAP_UNFINISHED);
break;
case ')': /* end capture */ case ')': /* end capture */
return end_capture(ms, s, p+1); s = end_capture(ms, s, p+1);
break;
case L_ESC: case L_ESC:
switch (*(p+1)) { switch (*(p+1)) {
case 'b': /* balanced string? */ case 'b': /* balanced string? */
s = matchbalance(ms, s, p+2); s = matchbalance(ms, s, p+2);
if (s == NULL) return NULL; if (s == NULL) break;
p+=4; p+=4;
goto init; /* else return match(ms, s, p+4); */ goto init; /* else s = match(ms, s, p+4); */
case 'f': { /* frontier? */ case 'f': { /* frontier? */
const char *ep; char previous; const char *ep; char previous;
p += 2; p += 2;
@ -363,50 +368,59 @@ static const char *match(MatchState *ms, const char *s, const char *p)
ep = classend(ms, p); /* points to what is next */ ep = classend(ms, p); /* points to what is next */
previous = (s == ms->src_init) ? '\0' : *(s-1); previous = (s == ms->src_init) ? '\0' : *(s-1);
if (matchbracketclass(uchar(previous), p, ep-1) || if (matchbracketclass(uchar(previous), p, ep-1) ||
!matchbracketclass(uchar(*s), p, ep-1)) return NULL; !matchbracketclass(uchar(*s), p, ep-1)) { s = NULL; break; }
p=ep; p=ep;
goto init; /* else return match(ms, s, ep); */ goto init; /* else s = match(ms, s, ep); */
} }
default: default:
if (lj_char_isdigit(uchar(*(p+1)))) { /* capture results (%0-%9)? */ if (lj_char_isdigit(uchar(*(p+1)))) { /* capture results (%0-%9)? */
s = match_capture(ms, s, uchar(*(p+1))); s = match_capture(ms, s, uchar(*(p+1)));
if (s == NULL) return NULL; if (s == NULL) break;
p+=2; p+=2;
goto init; /* else return match(ms, s, p+2) */ goto init; /* else s = match(ms, s, p+2) */
} }
goto dflt; /* case default */ goto dflt; /* case default */
} }
break;
case '\0': /* end of pattern */ case '\0': /* end of pattern */
return s; /* match succeeded */ break; /* match succeeded */
case '$': case '$':
if (*(p+1) == '\0') /* is the `$' the last char in pattern? */ /* is the `$' the last char in pattern? */
return (s == ms->src_end) ? s : NULL; /* check end of string */ if (*(p+1) != '\0') goto dflt;
else if (s != ms->src_end) s = NULL; /* check end of string */
goto dflt; break;
default: dflt: { /* it is a pattern item */ default: dflt: { /* it is a pattern item */
const char *ep = classend(ms, p); /* points to what is next */ const char *ep = classend(ms, p); /* points to what is next */
int m = s<ms->src_end && singlematch(uchar(*s), p, ep); int m = s<ms->src_end && singlematch(uchar(*s), p, ep);
switch (*ep) { switch (*ep) {
case '?': { /* optional */ case '?': { /* optional */
const char *res; const char *res;
if (m && ((res=match(ms, s+1, ep+1)) != NULL)) if (m && ((res=match(ms, s+1, ep+1)) != NULL)) {
return res; s = res;
break;
}
p=ep+1; p=ep+1;
goto init; /* else return match(ms, s, ep+1); */ goto init; /* else s = match(ms, s, ep+1); */
} }
case '*': /* 0 or more repetitions */ case '*': /* 0 or more repetitions */
return max_expand(ms, s, p, ep); s = max_expand(ms, s, p, ep);
break;
case '+': /* 1 or more repetitions */ case '+': /* 1 or more repetitions */
return (m ? max_expand(ms, s+1, p, ep) : NULL); s = (m ? max_expand(ms, s+1, p, ep) : NULL);
break;
case '-': /* 0 or more repetitions (minimum) */ case '-': /* 0 or more repetitions (minimum) */
return min_expand(ms, s, p, ep); s = min_expand(ms, s, p, ep);
break;
default: default:
if (!m) return NULL; if (m) { s++; p=ep; goto init; } /* else s = match(ms, s+1, ep); */
s++; p=ep; s = NULL;
goto init; /* else return match(ms, s+1, ep); */ break;
} }
break;
} }
} }
ms->depth--;
return s;
} }
static const char *lmemfind(const char *s1, size_t l1, static const char *lmemfind(const char *s1, size_t l1,
@ -495,7 +509,7 @@ static int str_find_aux(lua_State *L, int find)
ms.src_end = s+l1; ms.src_end = s+l1;
do { do {
const char *res; const char *res;
ms.level = 0; ms.level = ms.depth = 0;
if ((res=match(&ms, s1, p)) != NULL) { if ((res=match(&ms, s1, p)) != NULL) {
if (find) { if (find) {
lua_pushinteger(L, s1-s+1); /* start */ lua_pushinteger(L, s1-s+1); /* start */
@ -534,7 +548,7 @@ LJLIB_NOREG LJLIB_CF(string_gmatch_aux)
ms.src_end = s + str->len; ms.src_end = s + str->len;
for (; src <= ms.src_end; src++) { for (; src <= ms.src_end; src++) {
const char *e; const char *e;
ms.level = 0; ms.level = ms.depth = 0;
if ((e = match(&ms, src, p)) != NULL) { if ((e = match(&ms, src, p)) != NULL) {
int32_t pos = (int32_t)(e - s); int32_t pos = (int32_t)(e - s);
if (e == src) pos++; /* Ensure progress for empty match. */ if (e == src) pos++; /* Ensure progress for empty match. */
@ -628,7 +642,7 @@ LJLIB_CF(string_gsub)
ms.src_end = src+srcl; ms.src_end = src+srcl;
while (n < max_s) { while (n < max_s) {
const char *e; const char *e;
ms.level = 0; ms.level = ms.depth = 0;
e = match(&ms, src, p); e = match(&ms, src, p);
if (e) { if (e) {
n++; n++;

View File

@ -91,6 +91,7 @@ ERRDEF(STRPATC, "invalid pattern capture")
ERRDEF(STRPATE, "malformed pattern (ends with " LUA_QL("%") ")") ERRDEF(STRPATE, "malformed pattern (ends with " LUA_QL("%") ")")
ERRDEF(STRPATM, "malformed pattern (missing " LUA_QL("]") ")") ERRDEF(STRPATM, "malformed pattern (missing " LUA_QL("]") ")")
ERRDEF(STRPATU, "unbalanced pattern") ERRDEF(STRPATU, "unbalanced pattern")
ERRDEF(STRPATX, "pattern too complex")
ERRDEF(STRCAPI, "invalid capture index") ERRDEF(STRCAPI, "invalid capture index")
ERRDEF(STRCAPN, "too many captures") ERRDEF(STRCAPN, "too many captures")
ERRDEF(STRCAPU, "unfinished capture") ERRDEF(STRCAPU, "unfinished capture")