1
0
Fork 0
mirror of https://codeberg.org/forgejo/forgejo.git synced 2024-11-27 09:11:53 -05:00
forgejo/vendor/github.com/mrjones/oauth/oauth.go
2021-02-28 18:08:33 -05:00

1419 lines
42 KiB
Go
Vendored

// OAuth 1.0 consumer implementation.
// See http://www.oauth.net and RFC 5849
//
// There are typically three parties involved in an OAuth exchange:
// (1) The "Service Provider" (e.g. Google, Twitter, NetFlix) who operates the
// service where the data resides.
// (2) The "End User" who owns that data, and wants to grant access to a third-party.
// (3) That third-party who wants access to the data (after first being authorized by
// the user). This third-party is referred to as the "Consumer" in OAuth
// terminology.
//
// This library is designed to help implement the third-party consumer by handling the
// low-level authentication tasks, and allowing for authenticated requests to the
// service provider on behalf of the user.
//
// Caveats:
// - Currently only supports HMAC and RSA signatures.
// - Currently only supports SHA1 and SHA256 hashes.
// - Currently only supports OAuth 1.0
//
// Overview of how to use this library:
// (1) First create a new Consumer instance with the NewConsumer function
// (2) Get a RequestToken, and "authorization url" from GetRequestTokenAndUrl()
// (3) Save the RequestToken, you will need it again in step 6.
// (4) Redirect the user to the "authorization url" from step 2, where they will
// authorize your access to the service provider.
// (5) Wait. You will be called back on the CallbackUrl that you provide, and you
// will receive a "verification code".
// (6) Call AuthorizeToken() with the RequestToken from step 2 and the
// "verification code" from step 5.
// (7) You will get back an AccessToken. Save this for as long as you need access
// to the user's data, and treat it like a password; it is a secret.
// (8) You can now throw away the RequestToken from step 2, it is no longer
// necessary.
// (9) Call "MakeHttpClient" using the AccessToken from step 7 to get an
// HTTP client which can access protected resources.
package oauth
import (
"bytes"
"crypto"
"crypto/hmac"
cryptoRand "crypto/rand"
"crypto/rsa"
"encoding/base64"
"errors"
"fmt"
"io"
"io/ioutil"
"math/rand"
"mime/multipart"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"sync"
"time"
)
// Protocol-level constants used when building and signing OAuth 1.0 requests.
const (
	// OAUTH_VERSION is the value sent for the oauth_version parameter.
	OAUTH_VERSION = "1.0"
	// Signature method prefixes. HMACSigner.SignatureMethod appends the hash
	// name from HASH_METHOD_MAP to SIGNATURE_METHOD_HMAC (e.g. "HMAC-SHA1");
	// SIGNATURE_METHOD_RSA is presumably used the same way by the RSA signer.
	SIGNATURE_METHOD_HMAC = "HMAC-"
	SIGNATURE_METHOD_RSA = "RSA-"
	// HTTP_AUTH_HEADER is the request header that carries the OAuth
	// parameters; OAUTH_HEADER is the prefix of that header's value.
	HTTP_AUTH_HEADER = "Authorization"
	OAUTH_HEADER = "OAuth "
	// Names of the oauth_* protocol parameters.
	BODY_HASH_PARAM = "oauth_body_hash"
	CALLBACK_PARAM = "oauth_callback"
	CONSUMER_KEY_PARAM = "oauth_consumer_key"
	NONCE_PARAM = "oauth_nonce"
	SESSION_HANDLE_PARAM = "oauth_session_handle"
	SIGNATURE_METHOD_PARAM = "oauth_signature_method"
	SIGNATURE_PARAM = "oauth_signature"
	TIMESTAMP_PARAM = "oauth_timestamp"
	TOKEN_PARAM = "oauth_token"
	TOKEN_SECRET_PARAM = "oauth_token_secret"
	VERIFIER_PARAM = "oauth_verifier"
	VERSION_PARAM = "oauth_version"
)
// HASH_METHOD_MAP maps each supported hash function to the name that is
// appended to a signature-method prefix (for example "HMAC-" + "SHA1").
// Hashes missing from this map yield an empty name.
var HASH_METHOD_MAP = map[crypto.Hash]string{
	crypto.SHA1: "SHA1",
	crypto.SHA256: "SHA256",
}
// RequestToken is the temporary token/secret pair parsed from the service
// provider's request-token response. It must be kept until it is exchanged
// for an AccessToken via AuthorizeToken.
//
// TODO(mrjones) Do we definitely want separate "Request" and "Access" token classes?
// They're identical structurally, but used for different purposes.
type RequestToken struct {
	// Token is the oauth_token value from the response.
	Token string
	// Secret is the oauth_token_secret value from the response; it is used as
	// the token secret when signing the access-token request.
	Secret string
}
// AccessToken is the token/secret pair used to sign requests made on behalf
// of the user. Treat it like a password.
type AccessToken struct {
	// Token is the oauth_token value from the access-token response.
	Token string
	// Secret is the oauth_token_secret value from the access-token response.
	Secret string
	// AdditionalData holds every other parameter returned in the access-token
	// response (first value only), e.g. oauth_session_handle, which
	// RefreshToken requires.
	AdditionalData map[string]string
}
// DataLocation tells the deprecated request helpers where user parameters
// and the payload go in the outgoing HTTP request.
type DataLocation int
const (
	// LOC_BODY sends the user parameters form-encoded in the request body.
	LOC_BODY DataLocation = iota + 1
	// LOC_URL sends the user parameters in the URL query string.
	LOC_URL
	// LOC_MULTIPART streams the supplied body as a multipart form file;
	// user parameters go in the query string.
	LOC_MULTIPART
	// LOC_JSON sends the body with Content-Type application/json.
	LOC_JSON
	// LOC_XML sends the body with Content-Type application/xml.
	LOC_XML
)
// Information about how to contact the service provider (see #1 above).
// You usually find all of these URLs by reading the documentation for the service
// that you're trying to connect to.
// Some common examples are:
// (1) Google, standard APIs:
// http://code.google.com/apis/accounts/docs/OAuth_ref.html
// - RequestTokenUrl: https://www.google.com/accounts/OAuthGetRequestToken
// - AuthorizeTokenUrl: https://www.google.com/accounts/OAuthAuthorizeToken
// - AccessTokenUrl: https://www.google.com/accounts/OAuthGetAccessToken
// Note: Some Google APIs (for example, Google Latitude) use different values for
// one or more of those URLs.
// (2) Twitter API:
// http://dev.twitter.com/pages/auth
// - RequestTokenUrl: http://api.twitter.com/oauth/request_token
// - AuthorizeTokenUrl: https://api.twitter.com/oauth/authorize
// - AccessTokenUrl: https://api.twitter.com/oauth/access_token
// (3) NetFlix API:
// http://developer.netflix.com/docs/Security
// - RequestTokenUrl: http://api.netflix.com/oauth/request_token
// - AuthorizeTokenUrl: https://api-user.netflix.com/oauth/login
// - AccessTokenUrl: http://api.netflix.com/oauth/access_token
// Set HttpMethod if the service provider requires a different HTTP method
// to be used for OAuth token requests
type ServiceProvider struct {
	// RequestTokenUrl is the endpoint that issues temporary request tokens.
	RequestTokenUrl string
	// AuthorizeTokenUrl is the page the end user is redirected to in order to
	// authorize the request token.
	AuthorizeTokenUrl string
	// AccessTokenUrl is the endpoint that exchanges an authorized request
	// token for an access token.
	AccessTokenUrl string
	// HttpMethod is the HTTP method used for token requests; "GET" is used
	// when it is empty.
	HttpMethod string
	// BodyHash, when true, makes RoundTrip add an oauth_body_hash parameter
	// for requests whose body is not form-encoded.
	BodyHash bool
	// IgnoreTimestamp is not referenced in this part of the file.
	// NOTE(review): presumably relaxes timestamp checks on the provider
	// side - confirm against the rest of the package.
	IgnoreTimestamp bool
	// Enables non spec-compliant behavior:
	// Allow parameters to be passed in the query string rather
	// than the body.
	// See https://github.com/mrjones/oauth/pull/63
	SignQueryParams bool
}
// httpMethod reports the HTTP method to use for token requests: the
// configured HttpMethod, or "GET" when none has been set.
func (sp *ServiceProvider) httpMethod() string {
	method := "GET"
	if sp.HttpMethod != "" {
		method = sp.HttpMethod
	}
	return method
}
// lockedNonceGenerator wraps a non-reentrant random number generator with a
// lock so that Int63 can be called from multiple goroutines.
type lockedNonceGenerator struct {
	// nonceGenerator is the wrapped generator; access only while holding lock.
	nonceGenerator nonceGenerator
	// lock serializes calls into nonceGenerator.
	lock sync.Mutex
}
// newLockedNonceGenerator builds a lock-protected math/rand generator seeded
// with the current value of c.Nanos().
func newLockedNonceGenerator(c clock) *lockedNonceGenerator {
	seed := c.Nanos()
	generator := new(lockedNonceGenerator)
	generator.nonceGenerator = rand.New(rand.NewSource(seed))
	return generator
}
// Int63 returns the next value from the wrapped generator while holding the
// lock, so concurrent callers never enter the generator at the same time.
func (n *lockedNonceGenerator) Int63() int64 {
	n.lock.Lock()
	defer n.lock.Unlock()
	return n.nonceGenerator.Int63()
}
// Consumers are stateless, you can call the various methods (GetRequestTokenAndUrl,
// AuthorizeToken, and Get) on various different instances of Consumers *as long as
// they were set up in the same way.* It is up to you, as the caller to persist the
// necessary state (RequestTokens and AccessTokens).
type Consumer struct {
	// Some ServiceProviders require extra parameters to be passed for various reasons.
	// For example Google APIs require you to set a scope= parameter to specify how much
	// access is being granted. The proper values for scope= depend on the service:
	// For more, see: http://code.google.com/apis/accounts/docs/OAuth.html#prepScope
	//
	// These are added to the OAuth parameters of every signed request.
	AdditionalParams map[string]string
	// The rest of this class is configured via the NewConsumer function.
	// consumerKey is sent as the oauth_consumer_key parameter.
	consumerKey string
	// serviceProvider holds the endpoints and options for the provider.
	serviceProvider ServiceProvider
	// Some APIs (e.g. Netflix) aren't quite standard OAuth, and require passing
	// additional parameters when authorizing the request token. For most APIs
	// this field can be ignored. For Netflix, do something like:
	// consumer.AdditionalAuthorizationUrlParams = map[string]string{
	// "application_name": "YourAppName",
	// "oauth_consumer_key": "YourConsumerKey",
	// }
	AdditionalAuthorizationUrlParams map[string]string
	// debug enables printing of outgoing requests; set it through Debug.
	debug bool
	// Defaults to http.Client{}, can be overridden (e.g. for testing) as necessary
	HttpClient HttpClient
	// Some APIs (e.g. Intuit/Quickbooks) require sending additional headers along with
	// requests. (like "Accept" to specify the response type as XML or JSON) Note that this
	// will only *add* headers, not set existing ones.
	AdditionalHeaders map[string][]string
	// Private seams for mocking dependencies when testing
	// clock supplies the oauth_timestamp value.
	clock clock
	// Seeded generators are not reentrant
	// nonceGenerator supplies the oauth_nonce value.
	nonceGenerator nonceGenerator
	// signer produces the oauth_signature and names the signature method.
	signer signer
}
// newConsumer builds a Consumer with everything except the signer filled in.
// A nil httpClient is replaced by a default http.Client; the clock is the
// real wall clock and the nonce generator is seeded from it.
func newConsumer(consumerKey string, serviceProvider ServiceProvider, httpClient *http.Client) *Consumer {
	if httpClient == nil {
		httpClient = &http.Client{}
	}
	clk := &defaultClock{}

	consumer := &Consumer{
		AdditionalParams:                 make(map[string]string),
		AdditionalAuthorizationUrlParams: make(map[string]string),
		HttpClient:                       httpClient,
		consumerKey:                      consumerKey,
		serviceProvider:                  serviceProvider,
		clock:                            clk,
		nonceGenerator:                   newLockedNonceGenerator(clk),
	}
	return consumer
}
// NewConsumer creates a new Consumer instance with an HMAC-SHA1 signer and
// the default HTTP client.
// - consumerKey and consumerSecret:
// values you should obtain from the ServiceProvider when you register your
// application.
//
// - serviceProvider:
// see the documentation for ServiceProvider for how to create this.
//
func NewConsumer(consumerKey string, consumerSecret string,
	serviceProvider ServiceProvider) *Consumer {
	// Same construction as NewCustomConsumer with SHA1 and no custom client.
	return NewCustomConsumer(consumerKey, consumerSecret, crypto.SHA1, serviceProvider, nil)
}
// NewCustomHttpClientConsumer creates a new Consumer instance with an
// HMAC-SHA1 signer and a caller-supplied HTTP client.
// - consumerKey and consumerSecret:
// values you should obtain from the ServiceProvider when you register your
// application.
//
// - serviceProvider:
// see the documentation for ServiceProvider for how to create this.
//
// - httpClient:
// Provides a custom implementation of the httpClient used under the hood
// to make the request. This is especially useful if you want to use
// Google App Engine. A nil client selects the default http.Client.
//
func NewCustomHttpClientConsumer(consumerKey string, consumerSecret string,
	serviceProvider ServiceProvider, httpClient *http.Client) *Consumer {
	// Same construction as NewCustomConsumer with the hash fixed to SHA1.
	return NewCustomConsumer(consumerKey, consumerSecret, crypto.SHA1, serviceProvider, httpClient)
}
// NewCustomConsumer creates a new Consumer instance with an HMAC signer
// using the given hash function.
// - consumerKey and consumerSecret:
// values you should obtain from the ServiceProvider when you register your
// application.
//
// - hashFunc:
// the crypto.Hash to use for signatures
//
// - serviceProvider:
// see the documentation for ServiceProvider for how to create this.
//
// - httpClient:
// Provides a custom implementation of the httpClient used under the hood
// to make the request. This is especially useful if you want to use
// Google App Engine. Can be nil for default.
//
func NewCustomConsumer(consumerKey string, consumerSecret string,
	hashFunc crypto.Hash, serviceProvider ServiceProvider,
	httpClient *http.Client) *Consumer {
	hmacSigner := &HMACSigner{
		consumerSecret: consumerSecret,
		hashFunc:       hashFunc,
	}
	result := newConsumer(consumerKey, serviceProvider, httpClient)
	result.signer = hmacSigner
	return result
}
// NewRSAConsumer creates a new Consumer instance with an RSA-SHA1 signer and
// the default HTTP client.
// - consumerKey:
// value you should obtain from the ServiceProvider when you register your
// application.
//
// - privateKey:
// the private key to use for signatures
//
// - serviceProvider:
// see the documentation for ServiceProvider for how to create this.
//
func NewRSAConsumer(consumerKey string, privateKey *rsa.PrivateKey,
	serviceProvider ServiceProvider) *Consumer {
	// Same construction as NewCustomRSAConsumer with SHA1 and no custom client.
	return NewCustomRSAConsumer(consumerKey, privateKey, crypto.SHA1, serviceProvider, nil)
}
// NewCustomRSAConsumer creates a new Consumer instance with an RSA signer
// using the given hash function.
// - consumerKey:
// value you should obtain from the ServiceProvider when you register your
// application.
//
// - privateKey:
// the private key to use for signatures
//
// - hashFunc:
// the crypto.Hash to use for signatures
//
// - serviceProvider:
// see the documentation for ServiceProvider for how to create this.
//
// - httpClient:
// Provides a custom implementation of the httpClient used under the hood
// to make the request. This is especially useful if you want to use
// Google App Engine. Can be nil for default.
//
func NewCustomRSAConsumer(consumerKey string, privateKey *rsa.PrivateKey,
	hashFunc crypto.Hash, serviceProvider ServiceProvider,
	httpClient *http.Client) *Consumer {
	rsaSigner := &RSASigner{
		privateKey: privateKey,
		hashFunc:   hashFunc,
		rand:       cryptoRand.Reader,
	}
	result := newConsumer(consumerKey, serviceProvider, httpClient)
	result.signer = rsaSigner
	return result
}
// GetRequestTokenAndUrl kicks off the OAuth authorization process, using the
// consumer's AdditionalParams as extra request parameters.
// - callbackUrl:
// Authorizing a token *requires* redirecting to the service provider. This is the
// URL which the service provider will redirect the user back to after that
// authorization is completed. The service provider will pass back a verification
// code which is necessary to complete the rest of the process (in AuthorizeToken).
// Notes on callbackUrl:
// - Some (all?) service providers allow for setting "oob" (for out-of-band) as a
// callback url. If this is set the service provider will present the
// verification code directly to the user, and you must provide a place for
// them to copy-and-paste it into.
// - Otherwise, the user will be redirected to callbackUrl in the browser, and
// will append a "oauth_verifier=<verifier>" parameter.
//
// This function returns:
// - rtoken:
// A temporary RequestToken, used during the authorization process. You must save
// this since it will be necessary later in the process when calling
// AuthorizeToken().
//
// - url:
// A URL that you should redirect the user to in order that they may authorize you
// to the service provider.
//
// - err:
// Set only if there was an error, nil otherwise.
func (c *Consumer) GetRequestTokenAndUrl(callbackUrl string) (rtoken *RequestToken, loginUrl string, err error) {
	rtoken, loginUrl, err = c.GetRequestTokenAndUrlWithParams(callbackUrl, c.AdditionalParams)
	return rtoken, loginUrl, err
}
// GetRequestTokenAndUrlWithParams behaves like GetRequestTokenAndUrl but
// takes the extra request parameters explicitly instead of reading
// c.AdditionalParams. It signs and sends the request-token request, parses
// the token from the response, and builds the authorization URL from
// AuthorizeTokenUrl, AdditionalAuthorizationUrlParams and the new token.
func (c *Consumer) GetRequestTokenAndUrlWithParams(callbackUrl string, additionalParams map[string]string) (rtoken *RequestToken, loginUrl string, err error) {
	oauthParams := c.baseParams(c.consumerKey, additionalParams)
	if callbackUrl != "" {
		oauthParams.Add(CALLBACK_PARAM, callbackUrl)
	}

	unsigned := &request{
		method:      c.serviceProvider.httpMethod(),
		url:         c.serviceProvider.RequestTokenUrl,
		oauthParams: oauthParams,
	}
	// There is no token secret yet, so sign with an empty one.
	if _, signErr := c.signRequest(unsigned, ""); signErr != nil {
		return nil, "", signErr
	}

	body, err := c.getBody(c.serviceProvider.httpMethod(), c.serviceProvider.RequestTokenUrl, oauthParams)
	if err != nil {
		return nil, "", errors.New("getBody: " + err.Error())
	}

	requestToken, err := parseRequestToken(*body)
	if err != nil {
		return nil, "", errors.New("parseRequestToken: " + err.Error())
	}

	// The token is set last so it wins over any caller-supplied oauth_token.
	query := url.Values{}
	for name, value := range c.AdditionalAuthorizationUrlParams {
		query.Set(name, value)
	}
	query.Set(TOKEN_PARAM, requestToken.Token)

	return requestToken, c.serviceProvider.AuthorizeTokenUrl + "?" + query.Encode(), nil
}
// AuthorizeToken turns your temporary RequestToken into a permanent
// AccessToken once the user has authorized you to the service provider. The
// consumer's AdditionalParams are sent as extra parameters. You must pass in
// two values:
// - rtoken:
// The RequestToken returned from GetRequestTokenAndUrl()
//
// - verificationCode:
// The string which passed back from the server, either as the oauth_verifier
// query param appended to callbackUrl *OR* a string manually entered by the user
// if callbackUrl is "oob"
//
// It will return:
// - atoken:
// A permanent AccessToken which can be used to access the user's data (until it is
// revoked by the user or the service provider).
//
// - err:
// Set only if there was an error, nil otherwise.
func (c *Consumer) AuthorizeToken(rtoken *RequestToken, verificationCode string) (atoken *AccessToken, err error) {
	atoken, err = c.AuthorizeTokenWithParams(rtoken, verificationCode, c.AdditionalParams)
	return atoken, err
}
// AuthorizeTokenWithParams behaves like AuthorizeToken but takes the extra
// request parameters explicitly. The request carries the request token and,
// when verificationCode is non-empty, the oauth_verifier; it is signed with
// the request token's secret.
func (c *Consumer) AuthorizeTokenWithParams(rtoken *RequestToken, verificationCode string, additionalParams map[string]string) (atoken *AccessToken, err error) {
	tokenParams := make(map[string]string, 2)
	tokenParams[TOKEN_PARAM] = rtoken.Token
	if verificationCode != "" {
		tokenParams[VERIFIER_PARAM] = verificationCode
	}
	secret := rtoken.Secret
	return c.makeAccessTokenRequestWithParams(tokenParams, secret, additionalParams)
}
// RefreshToken uses the service provider to refresh the AccessToken for a
// given session. Note that this is only supported for service providers that
// manage an authorization session (e.g. Yahoo).
//
// Most providers do not return the SESSION_HANDLE_PARAM needed to refresh
// the token.
//
// See http://oauth.googlecode.com/svn/spec/ext/session/1.0/drafts/1/spec.html
// for more information.
// - accessToken:
// The AccessToken returned from AuthorizeToken()
//
// It will return:
// - atoken:
// An AccessToken which can be used to access the user's data (until it is
// revoked by the user or the service provider).
//
// - err:
// Set if accessToken does not contain the SESSION_HANDLE_PARAM needed to
// refresh the token, or if an error occurred when making the request.
func (c *Consumer) RefreshToken(accessToken *AccessToken) (atoken *AccessToken, err error) {
	sessionHandle, found := accessToken.AdditionalData[SESSION_HANDLE_PARAM]
	if !found {
		return nil, errors.New("Missing " + SESSION_HANDLE_PARAM + " in access token.")
	}

	refreshParams := map[string]string{
		SESSION_HANDLE_PARAM: sessionHandle,
		TOKEN_PARAM:          accessToken.Token,
	}
	return c.makeAccessTokenRequest(refreshParams, accessToken.Secret)
}
// makeAccessTokenRequest uses the service provider to obtain an AccessToken
// for a given session, sending the consumer's AdditionalParams as extra
// parameters.
// - params:
// The access token request parameters.
//
// - secret:
// Secret key to use when signing the access token request.
//
// It will return:
// - atoken
// An AccessToken which can be used to access the user's data (until it is
// revoked by the user or the service provider).
//
// - err:
// Set only if there was an error, nil otherwise.
func (c *Consumer) makeAccessTokenRequest(params map[string]string, secret string) (atoken *AccessToken, err error) {
	atoken, err = c.makeAccessTokenRequestWithParams(params, secret, c.AdditionalParams)
	return atoken, err
}
// makeAccessTokenRequestWithParams signs and sends a request to the
// provider's AccessTokenUrl and parses the AccessToken from the response.
// params are added on top of the base OAuth parameters (which already
// include additionalParams); secret is the token secret used for signing.
func (c *Consumer) makeAccessTokenRequestWithParams(params map[string]string, secret string, additionalParams map[string]string) (atoken *AccessToken, err error) {
	oauthParams := c.baseParams(c.consumerKey, additionalParams)
	for name, value := range params {
		oauthParams.Add(name, value)
	}

	unsigned := &request{
		method:      c.serviceProvider.httpMethod(),
		url:         c.serviceProvider.AccessTokenUrl,
		oauthParams: oauthParams,
	}
	if _, signErr := c.signRequest(unsigned, secret); signErr != nil {
		return nil, signErr
	}

	body, err := c.getBody(c.serviceProvider.httpMethod(), c.serviceProvider.AccessTokenUrl, oauthParams)
	if err != nil {
		return nil, err
	}
	return parseAccessToken(*body)
}
// RoundTripper signs each outgoing request with the consumer's credentials
// and the given access token, then sends it through the consumer's
// HttpClient. Its RoundTrip method has the http.RoundTripper signature.
type RoundTripper struct {
	// consumer supplies the key, signer, extra parameters and HTTP client.
	consumer *Consumer
	// token is the access token whose Token/Secret authorize the request.
	token *AccessToken
}
// MakeRoundTripper returns a RoundTripper that signs requests with this
// consumer and the given token. The returned error is always nil.
func (c *Consumer) MakeRoundTripper(token *AccessToken) (*RoundTripper, error) {
	rt := &RoundTripper{
		consumer: c,
		token:    token,
	}
	return rt, nil
}
// MakeHttpClient returns an *http.Client whose Transport signs every request
// with this consumer and the given token. The returned error is always nil.
func (c *Consumer) MakeHttpClient(token *AccessToken) (*http.Client, error) {
	client := new(http.Client)
	client.Transport = &RoundTripper{consumer: c, token: token}
	return client, nil
}
// Get executes an HTTP GET, authorized via the AccessToken.
//
// Deprecated: call Get on the http client returned by MakeHttpClient instead.
//
// - url:
// The base url, without any query params, which is being accessed
//
// - userParams:
// Any key=value params to be included in the query string
//
// - token:
// The AccessToken returned by AuthorizeToken()
//
// This method returns:
// - resp:
// The HTTP Response resulting from making this request.
//
// - err:
// Set only if there was an error, nil otherwise.
func (c *Consumer) Get(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
	resp, err = c.makeAuthorizedRequest("GET", url, LOC_URL, "", userParams, token)
	return resp, err
}
// encodeUserParams returns userParams in URL-encoded form ("a=1&b=2"),
// sorted by key. An empty or nil map yields the empty string.
func encodeUserParams(userParams map[string]string) string {
	values := make(url.Values, len(userParams))
	for name, value := range userParams {
		values[name] = []string{value}
	}
	return values.Encode()
}
// PostForm sends userParams as a form-encoded POST authorized via token.
//
// Deprecated: call "Post" on the http client returned by MakeHttpClient instead.
func (c *Consumer) PostForm(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
	const noBody = ""
	return c.PostWithBody(url, noBody, userParams, token)
}
// Post sends userParams as a form-encoded POST authorized via token. It is
// equivalent to PostForm.
//
// Deprecated: call "Post" on the http client returned by MakeHttpClient instead.
func (c *Consumer) Post(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
	const noBody = ""
	return c.PostWithBody(url, noBody, userParams, token)
}
// PostWithBody issues an authorized POST with the data location set to
// LOC_BODY, i.e. userParams are form-encoded into the request body.
//
// Deprecated: call "Post" on the http client returned by MakeHttpClient instead.
func (c *Consumer) PostWithBody(url string, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
	resp, err = c.makeAuthorizedRequest("POST", url, LOC_BODY, body, userParams, token)
	return resp, err
}
// PostJson issues an authorized POST of body with Content-Type
// application/json and no user parameters.
//
// Deprecated: call "Do" on the http client returned by MakeHttpClient instead
// (and set the "Content-Type" header explicitly in the http.Request).
func (c *Consumer) PostJson(url string, body string, token *AccessToken) (resp *http.Response, err error) {
	resp, err = c.makeAuthorizedRequest("POST", url, LOC_JSON, body, nil, token)
	return resp, err
}
// PostXML issues an authorized POST of body with Content-Type
// application/xml and no user parameters.
//
// Deprecated: call "Do" on the http client returned by MakeHttpClient instead
// (and set the "Content-Type" header explicitly in the http.Request).
func (c *Consumer) PostXML(url string, body string, token *AccessToken) (resp *http.Response, err error) {
	resp, err = c.makeAuthorizedRequest("POST", url, LOC_XML, body, nil, token)
	return resp, err
}
// PostMultipart issues an authorized POST that streams multipartData as a
// multipart form file named multipartName; userParams go in the query string.
// The content length passed down is 0.
//
// Deprecated: call "Do" on the http client returned by MakeHttpClient instead
// (and setup the multipart data explicitly in the http.Request).
func (c *Consumer) PostMultipart(url, multipartName string, multipartData io.ReadCloser, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
	const unknownLength = 0
	resp, err = c.makeAuthorizedRequestReader("POST", url, LOC_MULTIPART, unknownLength, multipartName, multipartData, userParams, token)
	return resp, err
}
// Delete issues an authorized DELETE with userParams in the query string.
//
// Deprecated: call "Delete" on the http client returned by MakeHttpClient instead.
func (c *Consumer) Delete(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
	resp, err = c.makeAuthorizedRequest("DELETE", url, LOC_URL, "", userParams, token)
	return resp, err
}
// Put issues an authorized PUT that sends body as the request body and
// userParams in the query string.
//
// Deprecated: call "Put" on the http client returned by MakeHttpClient instead.
func (c *Consumer) Put(url string, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
	resp, err = c.makeAuthorizedRequest("PUT", url, LOC_URL, body, userParams, token)
	return resp, err
}
// Debug turns debug output on or off for both the consumer (which prints
// each outgoing request in RoundTrip) and its signer.
func (c *Consumer) Debug(enabled bool) {
	c.debug = enabled
	// Keep the signer's debug flag in sync with the consumer's.
	c.signer.Debug(enabled)
}
// pair is a single key/value request parameter.
type pair struct {
	key   string
	value string
}

// pairs is a list of parameters that implements sort.Interface, ordering by
// key and then by value.
type pairs []pair

// Len returns the number of parameters.
func (p pairs) Len() int { return len(p) }

// Less orders parameters by key; parameters sharing a key are ordered by
// value. The value tie-break makes the result deterministic for repeated
// keys (sort.Sort is not stable) and matches the parameter ordering required
// for the OAuth signature base string.
func (p pairs) Less(i, j int) bool {
	if p[i].key != p[j].key {
		return p[i].key < p[j].key
	}
	return p[i].value < p[j].value
}

// Swap exchanges the parameters at positions i and j.
func (p pairs) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// This function has basically turned into a backwards compatibility layer
// between the old API (where clients explicitly called consumer.Get()
// consumer.Post() etc), and the new API (which takes actual http.Requests)
//
// So, here we construct the appropriate HTTP request for the inputs.
//
//   - method, urlString: HTTP method and target URL (parsed with url.Parse).
//   - dataLocation: where userParams and the payload go (see DataLocation).
//   - contentLength: initial ContentLength; replaced for LOC_BODY.
//   - multipartName: form-file field name, used only for LOC_MULTIPART.
//   - body: request payload; for LOC_BODY it is replaced by the encoded
//     userParams, for LOC_MULTIPART it is streamed into a multipart part.
//   - userParams: parameters placed in the body (LOC_BODY) or query string.
//   - token: access token used to sign the request.
//
// The request is signed and sent through RoundTripper.RoundTrip. A response
// with a status outside 2xx is returned together with an HTTPExecuteError
// carrying the (already read and closed) response body.
func (c *Consumer) makeAuthorizedRequestReader(method string, urlString string, dataLocation DataLocation, contentLength int, multipartName string, body io.ReadCloser, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
	urlObject, err := url.Parse(urlString)
	if err != nil {
		return nil, err
	}

	request := &http.Request{
		Method:        method,
		URL:           urlObject,
		Header:        http.Header{},
		Body:          body,
		ContentLength: int64(contentLength),
	}

	vals := url.Values{}
	for k, v := range userParams {
		vals.Add(k, v)
	}
	encodedParams := vals.Encode()

	if dataLocation != LOC_BODY {
		request.URL.RawQuery = strings.Replace(encodedParams, ";", "%3B", -1)
	} else {
		// TODO(mrjones): validate that we're not overriding an existing body?
		request.ContentLength = int64(len(encodedParams))
		if request.ContentLength == 0 {
			request.Body = nil
		} else {
			request.Body = ioutil.NopCloser(strings.NewReader(encodedParams))
		}
	}

	// Use Add rather than Set so that every value of a multi-valued
	// additional header is sent; Set would keep only the last one.
	for k, vs := range c.AdditionalHeaders {
		for _, v := range vs {
			request.Header.Add(k, v)
		}
	}

	switch dataLocation {
	case LOC_BODY:
		request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	case LOC_JSON:
		request.Header.Set("Content-Type", "application/json")
	case LOC_XML:
		request.Header.Set("Content-Type", "application/xml")
	case LOC_MULTIPART:
		pipeReader, pipeWriter := io.Pipe()
		writer := multipart.NewWriter(pipeWriter)
		// Fixed boundary so the library's own unit test can compare bodies.
		if request.URL.Host == "www.mrjon.es" &&
			request.URL.Path == "/unittest" {
			writer.SetBoundary("UNITTESTBOUNDARY")
		}
		// Stream the payload into the pipe while the HTTP client reads the
		// other end; the goroutine ends once the copy finishes or fails.
		go func(body io.Reader) {
			part, err := writer.CreateFormFile(multipartName, "/no/matter")
			if err != nil {
				writer.Close()
				pipeWriter.CloseWithError(err)
				return
			}
			_, err = io.Copy(part, body)
			if err != nil {
				writer.Close()
				pipeWriter.CloseWithError(err)
				return
			}
			writer.Close()
			pipeWriter.Close()
		}(body)
		request.Body = pipeReader
		request.Header.Set("Content-Type", writer.FormDataContentType())
	}

	rt := RoundTripper{consumer: c, token: token}
	resp, err = rt.RoundTrip(request)
	if err != nil {
		return resp, err
	}

	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		defer resp.Body.Close()
		// Best effort: a read failure still reports the HTTP status below.
		respBody, _ := ioutil.ReadAll(resp.Body)
		return resp, HTTPExecuteError{
			RequestHeaders:    "",
			ResponseBodyBytes: respBody,
			Status:            resp.Status,
			StatusCode:        resp.StatusCode,
		}
	}
	return resp, nil
}
// cloneReq returns a copy of src whose Header map (including each value
// slice) and URL struct are fresh copies; every other field is shared with
// src.
func cloneReq(src *http.Request) *http.Request {
	clone := new(http.Request)
	*clone = *src

	headers := make(http.Header, len(src.Header))
	for name, values := range src.Header {
		headers[name] = append([]string(nil), values...)
	}
	clone.Header = headers

	if src.URL != nil {
		clone.URL = cloneURL(src.URL)
	}
	return clone
}

// cloneURL returns a pointer to a shallow copy of *src.
func cloneURL(src *url.URL) *url.URL {
	copied := *src
	return &copied
}
// canonicalizeUrl returns the base string URI used for signing: scheme,
// host and path of u, without query or fragment. The scheme and host are
// lowercased and the scheme's default port (":80" for http, ":443" for
// https) is dropped, as the OAuth signature base string requires; any other
// port is kept.
func canonicalizeUrl(u *url.URL) string {
	scheme := strings.ToLower(u.Scheme)
	host := strings.ToLower(u.Host)
	switch scheme {
	case "http":
		host = strings.TrimSuffix(host, ":80")
	case "https":
		host = strings.TrimSuffix(host, ":443")
	}
	return scheme + "://" + host + u.Path
}
// getBody reads and returns the whole request body, then re-installs an
// equivalent in-memory body on the request so it can be read again. A nil
// body returns (nil, nil); a body that turns out to be empty leaves
// request.Body set to nil. The original body is closed before returning.
func getBody(request *http.Request) ([]byte, error) {
	original := request.Body
	if original == nil {
		return nil, nil
	}
	defer original.Close()

	data, err := ioutil.ReadAll(original)
	if err != nil {
		return nil, err
	}

	// Reading consumed the body, so hand the request a replayable copy.
	request.Body = nil
	if len(data) > 0 {
		request.Body = ioutil.NopCloser(bytes.NewReader(data))
	}
	return data, nil
}
// parseBody extracts the user parameters that take part in the signature.
// When the Content-Type header is exactly
// "application/x-www-form-urlencoded" they are parsed from the request body
// (read via getBody); otherwise they come from the URL query string.
func parseBody(request *http.Request) (map[string][]string, error) {
	// TODO(mrjones): factor parameter extraction into a separate method
	var source url.Values
	if request.Header.Get("Content-Type") == "application/x-www-form-urlencoded" {
		// x-www-form-urlencoded parameters come from the body:
		raw, err := getBody(request)
		if err != nil {
			return nil, err
		}
		source, err = url.ParseQuery(string(raw))
		if err != nil {
			return nil, err
		}
	} else {
		// Most of the time we get parameters from the query string:
		source = request.URL.Query()
	}

	userParams := map[string][]string{}
	for name, values := range source {
		userParams[name] = values
	}
	return userParams, nil
}
// paramsToSortedPairs flattens params into one pair per (key, value) and
// sorts the result with sort.Sort using the ordering defined on pairs. The
// result is never nil.
func paramsToSortedPairs(params map[string][]string) pairs {
	flattened := pairs{}
	for name, values := range params {
		for i := range values {
			flattened = append(flattened, pair{key: name, value: values[i]})
		}
	}
	sort.Sort(flattened)
	return flattened
}
// calculateBodyHash returns the base64-encoded hash of the request body,
// computed with the signer's hash function. Requests whose Content-Type is
// exactly "application/x-www-form-urlencoded" get no body hash (empty
// string). A missing body is hashed as empty input.
func calculateBodyHash(request *http.Request, s signer) (string, error) {
	const formContentType = "application/x-www-form-urlencoded"
	if request.Header.Get("Content-Type") == formContentType {
		return "", nil
	}

	var payload []byte
	if request.Body != nil {
		data, err := getBody(request)
		if err != nil {
			return "", err
		}
		payload = data
	}

	digest := s.HashFunc().New()
	digest.Write(payload)
	return base64.StdEncoding.EncodeToString(digest.Sum(nil)), nil
}
// RoundTrip signs a copy of userRequest and sends it through the consumer's
// HttpClient; userRequest itself is not modified apart from its body being
// read and re-installed when it is needed for signing.
//
// The signature covers the base OAuth parameters, the consumer's
// AdditionalParams, the token (if any), the optional body hash, and the
// user parameters extracted by parseBody. Only the OAuth parameters and the
// signature are placed in the Authorization header.
//
// A nil or empty token omits oauth_token and signs with an empty token
// secret, which allows two-legged OAuth calls.
func (rt *RoundTripper) RoundTrip(userRequest *http.Request) (*http.Response, error) {
	serverRequest := cloneReq(userRequest)

	// Treat a nil token like an empty one instead of dereferencing it.
	var tokenValue, tokenSecret string
	if rt.token != nil {
		tokenValue = rt.token.Token
		tokenSecret = rt.token.Secret
	}

	allParams := rt.consumer.baseParams(
		rt.consumer.consumerKey, rt.consumer.AdditionalParams)

	// Do not add the "oauth_token" parameter, if the access token has not been
	// specified. By omitting this parameter when it is not specified, allows
	// two-legged OAuth calls.
	if len(tokenValue) > 0 {
		allParams.Add(TOKEN_PARAM, tokenValue)
	}

	if rt.consumer.serviceProvider.BodyHash {
		bodyHash, err := calculateBodyHash(serverRequest, rt.consumer.signer)
		if err != nil {
			return nil, err
		}
		if bodyHash != "" {
			allParams.Add(BODY_HASH_PARAM, bodyHash)
		}
	}

	// Snapshot the OAuth-only parameters before user parameters are mixed
	// in: the header carries authParams, the signature covers allParams.
	authParams := allParams.Clone()

	// TODO(mrjones): put these directly into the paramPairs below?
	userParams, err := parseBody(serverRequest)
	if err != nil {
		return nil, err
	}
	paramPairs := paramsToSortedPairs(userParams)
	for i := range paramPairs {
		allParams.Add(paramPairs[i].key, paramPairs[i].value)
	}

	// An explicit Host on the request overrides the URL's host for signing.
	signingURL := cloneURL(serverRequest.URL)
	if host := serverRequest.Host; host != "" {
		signingURL.Host = host
	}
	baseString := rt.consumer.requestString(serverRequest.Method, canonicalizeUrl(signingURL), allParams)

	signature, err := rt.consumer.signer.Sign(baseString, tokenSecret)
	if err != nil {
		return nil, err
	}
	authParams.Add(SIGNATURE_PARAM, signature)

	// Set auth header: comma-separated key="value" entries after the prefix.
	var oauthHdr strings.Builder
	oauthHdr.WriteString(OAUTH_HEADER)
	first := true
	for _, key := range authParams.Keys() {
		for _, value := range authParams.Get(key) {
			if !first {
				oauthHdr.WriteString(",")
			}
			first = false
			oauthHdr.WriteString(key + "=\"" + value + "\"")
		}
	}
	serverRequest.Header.Add(HTTP_AUTH_HEADER, oauthHdr.String())

	if rt.consumer.debug {
		fmt.Printf("Request: %v\n", serverRequest)
	}

	return rt.consumer.HttpClient.Do(serverRequest)
}
// makeAuthorizedRequest is the string-body form of
// makeAuthorizedRequestReader: it wraps body in a reader, passes len(body)
// as the content length and uses no multipart name.
func (c *Consumer) makeAuthorizedRequest(method string, url string, dataLocation DataLocation, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
	reader := ioutil.NopCloser(strings.NewReader(body))
	return c.makeAuthorizedRequestReader(method, url, dataLocation, len(body), "", reader, userParams, token)
}
// request describes a token request that is about to be signed: signRequest
// builds the signature base string from method, url and oauthParams and then
// adds the signature to oauthParams.
type request struct {
	// method is the HTTP method used in the signature base string.
	method string
	// url is the endpoint URL used in the signature base string.
	url string
	// oauthParams are the OAuth parameters; the signature is added here.
	oauthParams *OrderedParams
	// userParams is not read in this part of the file.
	// NOTE(review): presumably non-OAuth parameters - confirm usage.
	userParams map[string]string
}
// HttpClient is the subset of *http.Client that the Consumer uses to send
// requests. It exists so the client can be replaced, e.g. in tests.
type HttpClient interface {
	// Do sends req and returns the response or an error.
	Do(req *http.Request) (resp *http.Response, err error)
}
// clock is the time source used by the Consumer; it is an interface so that
// it can be mocked in tests.
type clock interface {
	// Seconds supplies the value sent as oauth_timestamp.
	Seconds() int64
	// Nanos supplies the seed for the nonce generator.
	Nanos() int64
}
// nonceGenerator produces the numbers sent as oauth_nonce; it is an
// interface so that it can be mocked in tests.
type nonceGenerator interface {
	// Int63 returns the next nonce value.
	Int63() int64
}
// key is anything that can render itself as a string.
// NOTE(review): not referenced in this part of the file - confirm whether
// it is still used elsewhere in the package.
type key interface {
	String() string
}
// signer abstracts the signature algorithm (HMAC or RSA) used by a Consumer.
type signer interface {
	// Sign returns the signature of message (the signature base string) for
	// the given token secret.
	Sign(message string, tokenSecret string) (string, error)
	// Verify returns nil when signature is valid for message, and an error
	// otherwise.
	Verify(message string, signature string) error
	// SignatureMethod returns the value sent as oauth_signature_method.
	SignatureMethod() string
	// HashFunc returns the hash function used for signatures and body hashes.
	HashFunc() crypto.Hash
	// Debug turns the signer's debug output on or off.
	Debug(enabled bool)
}
// defaultClock is the clock implementation backed by the system time.
type defaultClock struct{}

// Seconds returns the current Unix time in seconds.
func (*defaultClock) Seconds() int64 {
	now := time.Now()
	return now.Unix()
}

// Nanos returns the current Unix time in nanoseconds.
func (*defaultClock) Nanos() int64 {
	now := time.Now()
	return now.UnixNano()
}
// signRequest computes the signature for req (from its method, url and
// oauthParams) with the given token secret and adds it to req.oauthParams as
// oauth_signature. It returns req on success, or nil and the signer's error.
func (c *Consumer) signRequest(req *request, tokenSecret string) (*request, error) {
	signature, err := c.signer.Sign(c.requestString(req.method, req.url, req.oauthParams), tokenSecret)
	if err != nil {
		return nil, err
	}
	req.oauthParams.Add(SIGNATURE_PARAM, signature)
	return req, nil
}
// parseAccessToken obtains an AccessToken from the response of a service
// provider.
// - data:
// The response body, in URL-encoded form.
//
// This method returns:
// - atoken:
// The AccessToken generated from the response body. Parameters other than
// oauth_token and oauth_token_secret end up in AdditionalData.
//
// - err:
// Set if an AccessToken could not be parsed from the given input, i.e. the
// body is not valid URL-encoded data or lacks the token or its secret.
func parseAccessToken(data string) (atoken *AccessToken, err error) {
	parts, err := url.ParseQuery(data)
	if err != nil {
		return nil, err
	}

	// Each required parameter is removed once read so that only the extra
	// parameters remain for AdditionalData.
	tokenParam := parts[TOKEN_PARAM]
	parts.Del(TOKEN_PARAM)
	if len(tokenParam) < 1 {
		return nil, errors.New("Missing " + TOKEN_PARAM + " in response. " +
			"Full response body: '" + data + "'")
	}

	tokenSecretParam := parts[TOKEN_SECRET_PARAM]
	parts.Del(TOKEN_SECRET_PARAM)
	if len(tokenSecretParam) < 1 {
		return nil, errors.New("Missing " + TOKEN_SECRET_PARAM + " in response. " +
			"Full response body: '" + data + "'")
	}

	return &AccessToken{
		Token:          tokenParam[0],
		Secret:         tokenSecretParam[0],
		AdditionalData: parseAdditionalData(parts),
	}, nil
}
// parseRequestToken builds a RequestToken from the URL-encoded response body
// of the request-token endpoint. It fails when the body cannot be parsed or
// when oauth_token or oauth_token_secret is missing; the error then includes
// the full response body.
func parseRequestToken(data string) (*RequestToken, error) {
	parts, err := url.ParseQuery(data)
	if err != nil {
		return nil, err
	}

	tokenParam := parts[TOKEN_PARAM]
	if len(tokenParam) < 1 {
		return nil, errors.New("Missing " + TOKEN_PARAM + " in response. " +
			"Full response body: '" + data + "'")
	}

	tokenSecretParam := parts[TOKEN_SECRET_PARAM]
	if len(tokenSecretParam) < 1 {
		return nil, errors.New("Missing " + TOKEN_SECRET_PARAM + " in response. " +
			"Full response body: '" + data + "'")
	}

	return &RequestToken{Token: tokenParam[0], Secret: tokenSecretParam[0]}, nil
}
// baseParams assembles the OAuth parameters common to every request: the
// protocol version, the signer's signature method, a timestamp taken from the
// consumer's clock, a nonce from its nonce generator, and the consumer key.
//
// Every entry of additionalParams is then added as well. All values pass
// through OrderedParams.Add, which percent-escapes them.
func (c *Consumer) baseParams(consumerKey string, additionalParams map[string]string) *OrderedParams {
	built := NewOrderedParams()
	built.Add(VERSION_PARAM, OAUTH_VERSION)

	sigMethod := c.signer.SignatureMethod()
	built.Add(SIGNATURE_METHOD_PARAM, sigMethod)

	timestamp := strconv.FormatInt(c.clock.Seconds(), 10)
	built.Add(TIMESTAMP_PARAM, timestamp)

	nonce := strconv.FormatInt(c.nonceGenerator.Int63(), 10)
	built.Add(NONCE_PARAM, nonce)

	built.Add(CONSUMER_KEY_PARAM, consumerKey)

	for name, val := range additionalParams {
		built.Add(name, val)
	}
	return built
}
// parseAdditionalData flattens parts into a plain map, keeping only the first
// value of each key. Keys that have no values are left out. The returned map
// is never nil.
func parseAdditionalData(parts url.Values) map[string]string {
	flat := map[string]string{}
	for name, values := range parts {
		if len(values) == 0 {
			continue
		}
		flat[name] = values[0]
	}
	return flat
}
// HMACSigner signs and verifies messages with an HMAC whose key is derived
// from the consumer secret and, when signing, a token secret.
type HMACSigner struct {
// consumerSecret is escaped and used as the first half of the HMAC key
// built in Sign.
consumerSecret string
// hashFunc is the hash returned by HashFunc; Sign passes its constructor
// to hmac.New.
hashFunc crypto.Hash
// debug, when true, makes Sign and Verify print their inputs and results
// to standard output.
debug bool
}

// Debug switches debug printing on or off for this signer.
func (s *HMACSigner) Debug(enabled bool) {
s.debug = enabled
}
// Sign returns the base64-encoded (standard encoding) HMAC of message.
//
// The HMAC key is the escaped consumer secret and the escaped tokenSecret
// joined by "&". The returned error is always nil.
//
// When debug is enabled the message, the key (which contains both secrets)
// and the resulting signature are printed to standard output.
func (s *HMACSigner) Sign(message string, tokenSecret string) (string, error) {
	signingKey := escape(s.consumerSecret) + "&" + escape(tokenSecret)
	if s.debug {
		fmt.Println("Signing:", message)
		fmt.Println("Key:", signingKey)
	}

	mac := hmac.New(s.HashFunc().New, []byte(signingKey))
	mac.Write([]byte(message))
	encoded := base64.StdEncoding.EncodeToString(mac.Sum(nil))

	if s.debug {
		fmt.Println("Base64 signature:", encoded)
	}
	return encoded, nil
}
// Verify reports whether signature is the signature this signer produces for
// message when signing with an empty token secret.
//
// signature is accepted either as the raw base64 string or in its
// query-escaped form. It returns nil on a match, the error from Sign if the
// expected signature cannot be computed, and a "signature did not match"
// error otherwise.
//
// Comparisons use hmac.Equal so that the time taken does not reveal how much
// of a forged signature was correct.
func (s *HMACSigner) Verify(message string, signature string) error {
	if s.debug {
		fmt.Println("Verifying Base64 signature:", signature)
	}

	validSignature, err := s.Sign(message, "")
	if err != nil {
		return err
	}
	expected := []byte(validSignature)

	if hmac.Equal(expected, []byte(signature)) {
		return nil
	}

	// The signature may have been handed over still query-escaped. A value
	// that cannot be unescaped is simply treated as a mismatch.
	decodedSignature, err := url.QueryUnescape(signature)
	if err == nil && hmac.Equal(expected, []byte(decodedSignature)) {
		return nil
	}

	return fmt.Errorf("signature did not match")
}
// SignatureMethod returns SIGNATURE_METHOD_HMAC followed by the
// HASH_METHOD_MAP entry for this signer's hash.
func (s *HMACSigner) SignatureMethod() string {
	hashName := HASH_METHOD_MAP[s.HashFunc()]
	return SIGNATURE_METHOD_HMAC + hashName
}

// HashFunc returns the hash this signer was configured with.
func (s *HMACSigner) HashFunc() crypto.Hash {
	return s.hashFunc
}
// RSASigner signs and verifies messages with RSA PKCS #1 v1.5 signatures
// over a hash of the message.
type RSASigner struct {
// debug, when true, makes Sign and Verify print their inputs and results
// to standard output.
debug bool
// rand is passed to rsa.SignPKCS1v15 as its random source.
rand io.Reader
// privateKey is used by Sign; Verify uses its public half.
privateKey *rsa.PrivateKey
// hashFunc is the hash returned by HashFunc and used to digest messages.
hashFunc crypto.Hash
}

// Debug switches debug printing on or off for this signer.
func (s *RSASigner) Debug(enabled bool) {
s.debug = enabled
}
// Sign returns the base64-encoded (standard encoding) RSA PKCS #1 v1.5
// signature of message, computed over the message's digest under the
// signer's hash.
//
// tokenSecret is not used by this signer. If rsa.SignPKCS1v15 fails, an empty
// string and that error are returned.
//
// When debug is enabled the message and the resulting signature are printed
// to standard output.
func (s *RSASigner) Sign(message string, tokenSecret string) (string, error) {
	if s.debug {
		fmt.Println("Signing:", message)
	}

	h := s.HashFunc().New()
	h.Write([]byte(message))
	digest := h.Sum(nil)

	signature, err := rsa.SignPKCS1v15(s.rand, s.privateKey, s.HashFunc(), digest)
	if err != nil {
		// Propagate the failure. Returning a nil error here would make the
		// caller send a request carrying an empty signature.
		return "", err
	}

	base64signature := base64.StdEncoding.EncodeToString(signature)
	if s.debug {
		fmt.Println("Base64 signature:", base64signature)
	}
	return base64signature, nil
}
// Verify checks base64signature against message using the public half of the
// signer's private key.
//
// The message is digested with the signer's hash, the signature is decoded
// from standard base64, and the pair is checked with rsa.VerifyPKCS1v15. It
// returns the decoding error if the signature is not valid base64, otherwise
// whatever rsa.VerifyPKCS1v15 returns (nil on success).
func (s *RSASigner) Verify(message string, base64signature string) error {
	if s.debug {
		fmt.Println("Verifying:", message)
		fmt.Println("Verifying Base64 signature:", base64signature)
	}

	hashFunc := s.HashFunc()
	hasher := hashFunc.New()
	hasher.Write([]byte(message))
	sum := hasher.Sum(nil)

	rawSig, decodeErr := base64.StdEncoding.DecodeString(base64signature)
	if decodeErr != nil {
		return decodeErr
	}

	pub := &s.privateKey.PublicKey
	return rsa.VerifyPKCS1v15(pub, hashFunc, sum, rawSig)
}
// SignatureMethod returns SIGNATURE_METHOD_RSA followed by the
// HASH_METHOD_MAP entry for this signer's hash.
func (s *RSASigner) SignatureMethod() string {
	hashName := HASH_METHOD_MAP[s.HashFunc()]
	return SIGNATURE_METHOD_RSA + hashName
}

// HashFunc returns the hash this signer was configured with.
func (s *RSASigner) HashFunc() crypto.Hash {
	return s.hashFunc
}
// escape percent-encodes s byte by byte. ASCII letters, digits and the
// characters '-', '.', '_' and '~' are copied unchanged; every other byte is
// written as '%' followed by two upper-case hexadecimal digits.
func escape(s string) string {
	const upperHex = "0123456789ABCDEF"

	// Worst case every input byte expands to a three-byte "%XX" sequence.
	out := make([]byte, 0, 3*len(s))
	for _, b := range []byte(s) {
		if !isEscapable(b) {
			out = append(out, b)
			continue
		}
		out = append(out, '%', upperHex[b>>4], upperHex[b&15])
	}
	return string(out)
}

// isEscapable reports whether b must be percent-encoded by escape, i.e.
// whether it is anything other than an ASCII letter, a digit, '-', '.', '_'
// or '~'.
func isEscapable(b byte) bool {
	switch {
	case 'A' <= b && b <= 'Z':
		return false
	case 'a' <= b && b <= 'z':
		return false
	case '0' <= b && b <= '9':
		return false
	}
	switch b {
	case '-', '.', '_', '~':
		return false
	}
	return true
}
// requestString builds the string that gets signed for a request.
//
// It starts with method, "&" and the escaped url. Each key/value pair from
// params (keys and values in the order returned by Keys and Get) is then
// appended as the escaped text "key=value". The very first pair is preceded
// by a literal "&"; every later pair is preceded by an escaped "&".
func (c *Consumer) requestString(method string, url string, params *OrderedParams) string {
	base := method + "&" + escape(url)
	for keyIdx, key := range params.Keys() {
		for valIdx, value := range params.Get(key) {
			separator := escape("&")
			if keyIdx+valIdx == 0 {
				separator = "&"
			}
			base += separator + escape(fmt.Sprintf("%s=%s", key, value))
		}
	}
	return base
}
// getBody performs a body-less request through httpExecute and returns the
// response body as a string.
//
// Errors from httpExecute and from reading the body are returned with an
// "httpExecute: " or "ReadAll: " prefix. They are wrapped rather than
// flattened, so a caller can still reach the underlying error (for example an
// HTTPExecuteError and its status code) with errors.As or errors.Is.
//
// When the consumer's debug flag is set, the response status and body are
// printed to standard output.
func (c *Consumer) getBody(method, url string, oauthParams *OrderedParams) (*string, error) {
	resp, err := c.httpExecute(method, url, "", 0, nil, oauthParams)
	if err != nil {
		return nil, fmt.Errorf("httpExecute: %w", err)
	}
	defer resp.Body.Close()

	bodyBytes, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("ReadAll: %w", err)
	}

	bodyStr := string(bodyBytes)
	if c.debug {
		fmt.Printf("STATUS: %d %s\n", resp.StatusCode, resp.Status)
		fmt.Println("BODY RESPONSE: " + bodyStr)
	}
	return &bodyStr, nil
}
// HTTPExecuteError signals that a call to httpExecute failed because the
// response status was outside the 2xx range.
type HTTPExecuteError struct {
	// RequestHeaders provides a stringified listing of request headers.
	RequestHeaders string
	// ResponseBodyBytes is the response read into a byte slice.
	ResponseBodyBytes []byte
	// Status is the status code string response.
	Status string
	// StatusCode is the parsed status code.
	StatusCode int
}

// Error provides a printable, multi-line description of an HTTPExecuteError
// listing the response status, status code, response body and request
// headers.
func (e HTTPExecuteError) Error() string {
	msg := "HTTP response is not 200/OK as expected. Actual response: \n"
	msg += "\tResponse Status: '" + e.Status + "'\n"
	msg += "\tResponse Code: " + strconv.Itoa(e.StatusCode) + "\n"
	msg += "\tResponse Body: " + string(e.ResponseBodyBytes) + "\n"
	msg += "\tRequest Headers: " + e.RequestHeaders
	return msg
}
// httpExecute builds an HTTP request carrying the given OAuth parameters in
// its Authorization header, sends it with the consumer's HttpClient and
// returns the response.
//
// The request's headers are, in order: the Authorization header, the
// consumer's AdditionalHeaders, Content-Type (only if contentType is
// non-empty) and Content-Length (only if contentLength is positive).
//
// Errors creating or sending the request are returned with a nil response.
// If the response status is outside the 2xx range, the body is read and
// closed and an HTTPExecuteError describing the response and the request
// headers is returned together with the response. On success the caller owns
// the response and must close its body.
func (c *Consumer) httpExecute(
	method string, urlStr string, contentType string, contentLength int, body io.Reader, oauthParams *OrderedParams) (*http.Response, error) {
	// Create base request.
	httpReq, err := http.NewRequest(method, urlStr, body)
	if err != nil {
		return nil, errors.New("NewRequest failed: " + err.Error())
	}

	// Build the auth header value: comma-separated key="value" pairs.
	authValue := "OAuth "
	for keyIdx, key := range oauthParams.Keys() {
		for valIdx, value := range oauthParams.Get(key) {
			if keyIdx+valIdx > 0 {
				authValue += ","
			}
			authValue += key + "=\"" + value + "\""
		}
	}
	httpReq.Header = http.Header{}
	httpReq.Header.Add("Authorization", authValue)

	// Add additional custom headers.
	for name, values := range c.AdditionalHeaders {
		for i := range values {
			httpReq.Header.Add(name, values[i])
		}
	}

	// Set contentType if passed.
	if len(contentType) > 0 {
		httpReq.Header.Set("Content-Type", contentType)
	}

	// Set contentLength if passed.
	// NOTE(review): net/http appears to derive the Content-Length it sends
	// from Request.ContentLength rather than from this header — confirm the
	// header has the intended effect.
	if contentLength > 0 {
		httpReq.Header.Set("Content-Length", strconv.Itoa(contentLength))
	}

	if c.debug {
		fmt.Printf("Request: %v\n", httpReq)
	}

	resp, err := c.HttpClient.Do(httpReq)
	if err != nil {
		return nil, errors.New("Do: " + err.Error())
	}

	// StatusMultipleChoices is 300, any 2xx response should be treated as success
	succeeded := resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices
	if succeeded {
		return resp, nil
	}

	defer resp.Body.Close()

	// Stringify the request headers for the error report.
	headerDump := ""
	for name, values := range httpReq.Header {
		for _, v := range values {
			headerDump += "[key: " + name + ", val: " + v + "]"
		}
	}

	// A read failure is ignored; whatever was read is reported.
	respBody, _ := ioutil.ReadAll(resp.Body)
	return resp, HTTPExecuteError{
		RequestHeaders:    headerDump,
		ResponseBodyBytes: respBody,
		Status:            resp.Status,
		StatusCode:        resp.StatusCode,
	}
}
//
// String Sorting helpers
//
// ByValue orders a slice of strings in ascending order. Its Len, Swap and
// Less methods make it usable with sort.Sort.
type ByValue []string

// Len returns the number of strings in the slice.
func (a ByValue) Len() int {
	count := len(a)
	return count
}

// Swap exchanges the strings at positions i and j.
func (a ByValue) Swap(i, j int) {
	held := a[i]
	a[i] = a[j]
	a[j] = held
}

// Less reports whether the string at position i sorts before the one at j.
func (a ByValue) Less(i, j int) bool {
	left, right := a[i], a[j]
	return left < right
}
//
// ORDERED PARAMS
//
// OrderedParams is a multi-valued string parameter collection whose keys can
// be retrieved in sorted order.
type OrderedParams struct {
	// allParams maps each key to all values added for it.
	allParams map[string][]string
	// keyOrdering lists each key once; Keys sorts it in place.
	keyOrdering []string
}

// NewOrderedParams returns an empty OrderedParams that is ready for use.
func NewOrderedParams() *OrderedParams {
	params := new(OrderedParams)
	params.allParams = map[string][]string{}
	params.keyOrdering = []string{}
	return params
}
// Get returns the values stored for key, sorted in ascending order.
//
// The stored slice itself is sorted in place and returned, not a copy. For a
// key that was never added the result is nil.
func (o *OrderedParams) Get(key string) []string {
	values := o.allParams[key]
	sort.Strings(values)
	return values
}
// Keys returns the keys in ascending order.
//
// The receiver's own key list is sorted in place and returned, not a copy.
func (o *OrderedParams) Keys() []string {
	sort.Strings(o.keyOrdering)
	return o.keyOrdering
}
// Add percent-escapes value and stores it under key.
func (o *OrderedParams) Add(key, value string) {
	escaped := escape(value)
	o.AddUnescaped(key, escaped)
}

// AddUnescaped stores value under key exactly as given. A key seen for the
// first time is also recorded in the key list; later values for the same key
// are appended to its existing values.
func (o *OrderedParams) AddUnescaped(key, value string) {
	values, seen := o.allParams[key]
	if !seen {
		o.keyOrdering = append(o.keyOrdering, key)
	}
	o.allParams[key] = append(values, value)
}
// Len returns the number of distinct keys.
func (o *OrderedParams) Len() int {
	keys := o.keyOrdering
	return len(keys)
}

// Less reports whether the key at position i sorts before the key at j.
func (o *OrderedParams) Less(i int, j int) bool {
	left, right := o.keyOrdering[i], o.keyOrdering[j]
	return left < right
}

// Swap exchanges the keys at positions i and j.
func (o *OrderedParams) Swap(i int, j int) {
	keys := o.keyOrdering
	keys[i], keys[j] = keys[j], keys[i]
}
// Clone returns a new OrderedParams holding the same key/value pairs.
//
// Values are copied as stored, without being escaped again. Because the pairs
// are read through Keys and Get, the receiver's keys and values end up sorted
// in place as a side effect.
func (o *OrderedParams) Clone() *OrderedParams {
	dup := NewOrderedParams()
	keys := o.Keys()
	for i := range keys {
		for _, v := range o.Get(keys[i]) {
			dup.AddUnescaped(keys[i], v)
		}
	}
	return dup
}