1
0
Fork 0
mirror of https://codeberg.org/forgejo/forgejo.git synced 2025-01-09 15:28:22 -05:00

Protect against NPEs in notifications list (#10879)

Unfortunately there appears to be potential race with notifications
being set before the associated issue has been committed.

This PR adds protection in to the notifications list to log any failures
and remove these notifications from the display.

References #10815 - and prevents the panic but does not completely fix
this.

Signed-off-by: Andrew Thornton <art27@cantab.net>
This commit is contained in:
zeripath 2020-03-29 20:51:14 +01:00 committed by GitHub
parent 20d4f9206d
commit d01763ee14
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 75 additions and 22 deletions

View file

@ -481,9 +481,9 @@ func (nl NotificationList) getPendingRepoIDs() []int64 {
} }
// LoadRepos loads repositories from database // LoadRepos loads repositories from database
func (nl NotificationList) LoadRepos() (RepositoryList, error) { func (nl NotificationList) LoadRepos() (RepositoryList, []int, error) {
if len(nl) == 0 { if len(nl) == 0 {
return RepositoryList{}, nil return RepositoryList{}, []int{}, nil
} }
var repoIDs = nl.getPendingRepoIDs() var repoIDs = nl.getPendingRepoIDs()
@ -498,7 +498,7 @@ func (nl NotificationList) LoadRepos() (RepositoryList, error) {
In("id", repoIDs[:limit]). In("id", repoIDs[:limit]).
Rows(new(Repository)) Rows(new(Repository))
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
for rows.Next() { for rows.Next() {
@ -506,7 +506,7 @@ func (nl NotificationList) LoadRepos() (RepositoryList, error) {
err = rows.Scan(&repo) err = rows.Scan(&repo)
if err != nil { if err != nil {
rows.Close() rows.Close()
return nil, err return nil, nil, err
} }
repos[repo.ID] = &repo repos[repo.ID] = &repo
@ -517,14 +517,21 @@ func (nl NotificationList) LoadRepos() (RepositoryList, error) {
repoIDs = repoIDs[limit:] repoIDs = repoIDs[limit:]
} }
failed := []int{}
var reposList = make(RepositoryList, 0, len(repoIDs)) var reposList = make(RepositoryList, 0, len(repoIDs))
for _, notification := range nl { for i, notification := range nl {
if notification.Repository == nil { if notification.Repository == nil {
notification.Repository = repos[notification.RepoID] notification.Repository = repos[notification.RepoID]
} }
if notification.Repository == nil {
log.Error("Notification[%d]: RepoID: %d not found", notification.ID, notification.RepoID)
failed = append(failed, i)
continue
}
var found bool var found bool
for _, r := range reposList { for _, r := range reposList {
if r.ID == notification.Repository.ID { if r.ID == notification.RepoID {
found = true found = true
break break
} }
@ -533,7 +540,7 @@ func (nl NotificationList) LoadRepos() (RepositoryList, error) {
reposList = append(reposList, notification.Repository) reposList = append(reposList, notification.Repository)
} }
} }
return reposList, nil return reposList, failed, nil
} }
func (nl NotificationList) getPendingIssueIDs() []int64 { func (nl NotificationList) getPendingIssueIDs() []int64 {
@ -550,9 +557,9 @@ func (nl NotificationList) getPendingIssueIDs() []int64 {
} }
// LoadIssues loads issues from database // LoadIssues loads issues from database
func (nl NotificationList) LoadIssues() error { func (nl NotificationList) LoadIssues() ([]int, error) {
if len(nl) == 0 { if len(nl) == 0 {
return nil return []int{}, nil
} }
var issueIDs = nl.getPendingIssueIDs() var issueIDs = nl.getPendingIssueIDs()
@ -567,7 +574,7 @@ func (nl NotificationList) LoadIssues() error {
In("id", issueIDs[:limit]). In("id", issueIDs[:limit]).
Rows(new(Issue)) Rows(new(Issue))
if err != nil { if err != nil {
return err return nil, err
} }
for rows.Next() { for rows.Next() {
@ -575,7 +582,7 @@ func (nl NotificationList) LoadIssues() error {
err = rows.Scan(&issue) err = rows.Scan(&issue)
if err != nil { if err != nil {
rows.Close() rows.Close()
return err return nil, err
} }
issues[issue.ID] = &issue issues[issue.ID] = &issue
@ -586,13 +593,38 @@ func (nl NotificationList) LoadIssues() error {
issueIDs = issueIDs[limit:] issueIDs = issueIDs[limit:]
} }
for _, notification := range nl { failures := []int{}
for i, notification := range nl {
if notification.Issue == nil { if notification.Issue == nil {
notification.Issue = issues[notification.IssueID] notification.Issue = issues[notification.IssueID]
if notification.Issue == nil {
log.Error("Notification[%d]: IssueID: %d Not Found", notification.ID, notification.IssueID)
failures = append(failures, i)
continue
}
notification.Issue.Repo = notification.Repository notification.Issue.Repo = notification.Repository
} }
} }
return nil return failures, nil
}
// Without returns the notification list without the failures
func (nl NotificationList) Without(failures []int) NotificationList {
if failures == nil || len(failures) == 0 {
return nl
}
remaining := make([]*Notification, 0, len(nl))
last := -1
var i int
for _, i = range failures {
remaining = append(remaining, nl[last+1:i]...)
last = i
}
if len(nl) > i {
remaining = append(remaining, nl[i+1:]...)
}
return remaining
} }
func (nl NotificationList) getPendingCommentIDs() []int64 { func (nl NotificationList) getPendingCommentIDs() []int64 {
@ -609,9 +641,9 @@ func (nl NotificationList) getPendingCommentIDs() []int64 {
} }
// LoadComments loads comments from database // LoadComments loads comments from database
func (nl NotificationList) LoadComments() error { func (nl NotificationList) LoadComments() ([]int, error) {
if len(nl) == 0 { if len(nl) == 0 {
return nil return []int{}, nil
} }
var commentIDs = nl.getPendingCommentIDs() var commentIDs = nl.getPendingCommentIDs()
@ -626,7 +658,7 @@ func (nl NotificationList) LoadComments() error {
In("id", commentIDs[:limit]). In("id", commentIDs[:limit]).
Rows(new(Comment)) Rows(new(Comment))
if err != nil { if err != nil {
return err return nil, err
} }
for rows.Next() { for rows.Next() {
@ -634,7 +666,7 @@ func (nl NotificationList) LoadComments() error {
err = rows.Scan(&comment) err = rows.Scan(&comment)
if err != nil { if err != nil {
rows.Close() rows.Close()
return err return nil, err
} }
comments[comment.ID] = &comment comments[comment.ID] = &comment
@ -645,13 +677,19 @@ func (nl NotificationList) LoadComments() error {
commentIDs = commentIDs[limit:] commentIDs = commentIDs[limit:]
} }
for _, notification := range nl { failures := []int{}
for i, notification := range nl {
if notification.CommentID > 0 && notification.Comment == nil && comments[notification.CommentID] != nil { if notification.CommentID > 0 && notification.Comment == nil && comments[notification.CommentID] != nil {
notification.Comment = comments[notification.CommentID] notification.Comment = comments[notification.CommentID]
if notification.Comment == nil {
log.Error("Notification[%d]: CommentID[%d] failed to load", notification.ID, notification.CommentID)
failures = append(failures, i)
continue
}
notification.Comment.Issue = notification.Issue notification.Comment.Issue = notification.Issue
} }
} }
return nil return failures, nil
} }
// GetNotificationCount returns the notification count for user // GetNotificationCount returns the notification count for user

View file

@ -81,24 +81,39 @@ func Notifications(c *context.Context) {
return return
} }
repos, err := notifications.LoadRepos() failCount := 0
repos, failures, err := notifications.LoadRepos()
if err != nil { if err != nil {
c.ServerError("LoadRepos", err) c.ServerError("LoadRepos", err)
return return
} }
notifications = notifications.Without(failures)
if err := repos.LoadAttributes(); err != nil { if err := repos.LoadAttributes(); err != nil {
c.ServerError("LoadAttributes", err) c.ServerError("LoadAttributes", err)
return return
} }
failCount += len(failures)
if err := notifications.LoadIssues(); err != nil { failures, err = notifications.LoadIssues()
if err != nil {
c.ServerError("LoadIssues", err) c.ServerError("LoadIssues", err)
return return
} }
if err := notifications.LoadComments(); err != nil { notifications = notifications.Without(failures)
failCount += len(failures)
failures, err = notifications.LoadComments()
if err != nil {
c.ServerError("LoadComments", err) c.ServerError("LoadComments", err)
return return
} }
notifications = notifications.Without(failures)
failCount += len(failures)
if failCount > 0 {
c.Flash.Error(fmt.Sprintf("ERROR: %d notifications were removed due to missing parts - check the logs", failCount))
}
title := c.Tr("notifications") title := c.Tr("notifications")
if status == models.NotificationStatusUnread && total > 0 { if status == models.NotificationStatusUnread && total > 0 {