Limit recursion depth in string.match() et al.

This commit is contained in:
Mike Pall 2012-08-28 21:22:23 +02:00
parent 751cd9d821
commit ff00a78f3a
2 changed files with 41 additions and 26 deletions

View File

@ -148,6 +148,7 @@ typedef struct MatchState {
const char *src_end; /* end (`\0') of source string */
lua_State *L;
int level; /* total number of captures (finished or unfinished) */
int depth;
struct {
const char *init;
ptrdiff_t len;
@ -339,22 +340,26 @@ static const char *match_capture(MatchState *ms, const char *s, int l)
static const char *match(MatchState *ms, const char *s, const char *p)
{
if (++ms->depth > LJ_MAX_XLEVEL)
lj_err_caller(ms->L, LJ_ERR_STRPATX);
init: /* using goto's to optimize tail recursion */
switch (*p) {
case '(': /* start capture */
if (*(p+1) == ')') /* position capture? */
return start_capture(ms, s, p+2, CAP_POSITION);
s = start_capture(ms, s, p+2, CAP_POSITION);
else
return start_capture(ms, s, p+1, CAP_UNFINISHED);
s = start_capture(ms, s, p+1, CAP_UNFINISHED);
break;
case ')': /* end capture */
return end_capture(ms, s, p+1);
s = end_capture(ms, s, p+1);
break;
case L_ESC:
switch (*(p+1)) {
case 'b': /* balanced string? */
s = matchbalance(ms, s, p+2);
if (s == NULL) return NULL;
if (s == NULL) break;
p+=4;
goto init; /* else return match(ms, s, p+4); */
goto init; /* else s = match(ms, s, p+4); */
case 'f': { /* frontier? */
const char *ep; char previous;
p += 2;
@ -363,50 +368,59 @@ static const char *match(MatchState *ms, const char *s, const char *p)
ep = classend(ms, p); /* points to what is next */
previous = (s == ms->src_init) ? '\0' : *(s-1);
if (matchbracketclass(uchar(previous), p, ep-1) ||
!matchbracketclass(uchar(*s), p, ep-1)) return NULL;
!matchbracketclass(uchar(*s), p, ep-1)) { s = NULL; break; }
p=ep;
goto init; /* else return match(ms, s, ep); */
goto init; /* else s = match(ms, s, ep); */
}
default:
if (lj_char_isdigit(uchar(*(p+1)))) { /* capture results (%0-%9)? */
s = match_capture(ms, s, uchar(*(p+1)));
if (s == NULL) return NULL;
if (s == NULL) break;
p+=2;
goto init; /* else return match(ms, s, p+2) */
goto init; /* else s = match(ms, s, p+2) */
}
goto dflt; /* case default */
}
break;
case '\0': /* end of pattern */
return s; /* match succeeded */
break; /* match succeeded */
case '$':
if (*(p+1) == '\0') /* is the `$' the last char in pattern? */
return (s == ms->src_end) ? s : NULL; /* check end of string */
else
goto dflt;
/* is the `$' the last char in pattern? */
if (*(p+1) != '\0') goto dflt;
if (s != ms->src_end) s = NULL; /* check end of string */
break;
default: dflt: { /* it is a pattern item */
const char *ep = classend(ms, p); /* points to what is next */
int m = s<ms->src_end && singlematch(uchar(*s), p, ep);
switch (*ep) {
case '?': { /* optional */
const char *res;
if (m && ((res=match(ms, s+1, ep+1)) != NULL))
return res;
if (m && ((res=match(ms, s+1, ep+1)) != NULL)) {
s = res;
break;
}
p=ep+1;
goto init; /* else return match(ms, s, ep+1); */
goto init; /* else s = match(ms, s, ep+1); */
}
case '*': /* 0 or more repetitions */
return max_expand(ms, s, p, ep);
s = max_expand(ms, s, p, ep);
break;
case '+': /* 1 or more repetitions */
return (m ? max_expand(ms, s+1, p, ep) : NULL);
s = (m ? max_expand(ms, s+1, p, ep) : NULL);
break;
case '-': /* 0 or more repetitions (minimum) */
return min_expand(ms, s, p, ep);
s = min_expand(ms, s, p, ep);
break;
default:
if (!m) return NULL;
s++; p=ep;
goto init; /* else return match(ms, s+1, ep); */
}
if (m) { s++; p=ep; goto init; } /* else s = match(ms, s+1, ep); */
s = NULL;
break;
}
break;
}
}
ms->depth--;
return s;
}
static const char *lmemfind(const char *s1, size_t l1,
@ -495,7 +509,7 @@ static int str_find_aux(lua_State *L, int find)
ms.src_end = s+l1;
do {
const char *res;
ms.level = 0;
ms.level = ms.depth = 0;
if ((res=match(&ms, s1, p)) != NULL) {
if (find) {
lua_pushinteger(L, s1-s+1); /* start */
@ -534,7 +548,7 @@ LJLIB_NOREG LJLIB_CF(string_gmatch_aux)
ms.src_end = s + str->len;
for (; src <= ms.src_end; src++) {
const char *e;
ms.level = 0;
ms.level = ms.depth = 0;
if ((e = match(&ms, src, p)) != NULL) {
int32_t pos = (int32_t)(e - s);
if (e == src) pos++; /* Ensure progress for empty match. */
@ -628,7 +642,7 @@ LJLIB_CF(string_gsub)
ms.src_end = src+srcl;
while (n < max_s) {
const char *e;
ms.level = 0;
ms.level = ms.depth = 0;
e = match(&ms, src, p);
if (e) {
n++;

View File

@ -91,6 +91,7 @@ ERRDEF(STRPATC, "invalid pattern capture")
ERRDEF(STRPATE, "malformed pattern (ends with " LUA_QL("%") ")")
ERRDEF(STRPATM, "malformed pattern (missing " LUA_QL("]") ")")
ERRDEF(STRPATU, "unbalanced pattern")
ERRDEF(STRPATX, "pattern too complex")
ERRDEF(STRCAPI, "invalid capture index")
ERRDEF(STRCAPN, "too many captures")
ERRDEF(STRCAPU, "unfinished capture")