# blob: 90fdc745c06a07931424ccc988b4df2f14ef1095 [file] [log] [blame]
# Copyright (C) 2019 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import concurrent.futures
import google.auth
import google.auth.transport.requests
import json
import logging
import os
import requests
from base64 import b64encode
from datetime import datetime
from config import PROJECT
# Thread pool for making http requests asynchronously.
thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=8)
# OAuth2 scopes passed to google.auth.default() by get_access_token().
# Caller has to initialize this.
SCOPES = []
# Lazily-filled caches: the (user, gitcookie) tuple stored by
# get_gerrit_credentials() and the credentials object stored by
# get_access_token().
cached_gerrit_creds = None
cached_oauth2_creds = None
class ConcurrentModificationError(Exception):
  '''Raised by req() when the server answers with HTTP status 412.'''
def get_access_token():
  '''Return an OAuth2 access token for SCOPES, refreshing it when needed.

  The credentials object is kept in the module-level cached_oauth2_creds and
  reused for as long as it reports itself valid and not expired; otherwise
  new default credentials are obtained, refreshed and cached.
  '''
  global cached_oauth2_creds
  cached = cached_oauth2_creds
  # Fast path: the cached credentials are still usable.
  if cached is not None and cached.valid and not cached.expired:
    return cached.token
  fresh, _ = google.auth.default(scopes=SCOPES)
  fresh.refresh(google.auth.transport.requests.Request())
  cached_oauth2_creds = fresh
  return fresh.token
def get_gerrit_credentials():
  '''Return the (user, gitcookie) tuple used to authenticate Gerrit requests.

  Both fields are read from the first 'GerritAuth' entity returned by a
  datastore runQuery on PROJECT. They are obtained from the Gerrit
  'New HTTP password' page, which generates a .gitcookie file.
    user: typically looks like git-user.gmail.com.
    gitcookie: is the password after the = token.
  The tuple is fetched once and afterwards served from cached_gerrit_creds.
  '''
  global cached_gerrit_creds
  if cached_gerrit_creds is not None:
    return cached_gerrit_creds
  query_url = ('https://datastore.googleapis.com/v1/projects/%s:runQuery' %
               PROJECT)
  query = {'query': {'kind': [{'name': 'GerritAuth'}]}}
  reply = req('POST', query_url, body=query)
  props = reply['batch']['entityResults'][0]['entity']['properties']
  cached_gerrit_creds = (props['user']['stringValue'],
                         props['gitcookie']['stringValue'])
  return cached_gerrit_creds
async def req_async(method, url, body=None, gerrit=False):
  '''Run req() on the shared thread pool and await its result.

  The call is made with req_etag=False and etag=None, so the awaited value is
  whatever req() returns in that mode.
  '''
  # run_in_executor cannot forward kwargs, so all req() parameters are passed
  # positionally, in declaration order.
  positional = (method, url, body, gerrit, False, None)
  event_loop = asyncio.get_running_loop()
  return await event_loop.run_in_executor(thread_pool, req, *positional)
def req(method, url, body=None, gerrit=False, req_etag=False, etag=None):
  '''Helper function to handle authenticated HTTP requests.

  Cloud API and Gerrit require two different types of authentication and as
  such need to be handled differently: Gerrit requests use HTTP Basic auth
  built from get_gerrit_credentials(); any other request uses an OAuth2 Bearer
  token from get_access_token() when SCOPES is non-empty. When neither applies
  the request is sent without an Authorization header.

  Args:
    method: HTTP method name, forwarded to requests.request().
    url: target URL.
    body: optional JSON-serializable payload, sent as the request body.
    gerrit: True to authenticate with the Gerrit credentials and to strip the
      Gerrit XSSI prefix from the response before parsing it.
    req_etag: True to send 'X-Firebase-ETag: true'; the return value then
      becomes a (result, etag) tuple.
    etag: optional value sent in the 'if-match' header.

  Returns:
    The JSON-decoded response body, or (decoded body, response etag header)
    when req_etag is true.

  Raises:
    ConcurrentModificationError: the server replied with HTTP 412.
    Exception: any other non-200 status; args are (response, body text).
  '''
  hdr = {'Content-Type': 'application/json; charset=UTF-8'}
  # Only attach credentials that actually exist. Previously a non-Gerrit call
  # with an empty SCOPES crashed with UnboundLocalError on auth_header.
  if gerrit:
    creds = '%s:%s' % get_gerrit_credentials()
    hdr['Authorization'] = (
        'Basic ' + b64encode(creds.encode('utf-8')).decode('utf-8'))
  elif SCOPES:
    hdr['Authorization'] = 'Bearer ' + get_access_token()
  logging.debug('%s %s [gerrit=%d]', method, url, gerrit)
  if req_etag:
    hdr['X-Firebase-ETag'] = 'true'
  if etag:
    hdr['if-match'] = etag
  body = None if body is None else json.dumps(body)
  resp = requests.request(method, url, headers=hdr, data=body, timeout=60)
  res = resp.content.decode('utf-8')
  resp_etag = resp.headers.get('etag')
  if resp.status_code == 200:
    # [4:] is to strip Gerrit XSSI protection prefix.
    res = json.loads(res[4:] if gerrit else res)
    return (res, resp_etag) if req_etag else res
  if resp.status_code == 412:
    raise ConcurrentModificationError()
  raise Exception(resp, res)
# Datetime functions to deal with the fact that Javascript expects a trailing
# 'Z' (Z == 'Zulu' == UTC) for timestamps.
def parse_iso_time(time_str):
  '''Parse a 'YYYY-MM-DDTHH:MM:SSZ' string into a naive datetime.'''
  iso_zulu_format = r'%Y-%m-%dT%H:%M:%SZ'
  parsed = datetime.strptime(time_str, iso_zulu_format)
  return parsed
def utc_now_iso(utcnow=None):
  '''Format a timestamp as 'YYYY-MM-DDTHH:MM:SSZ'.

  Args:
    utcnow: optional datetime to format. Its fields are emitted as they are,
      no timezone conversion is performed. When omitted (or falsy) the
      current UTC time is used.

  Returns:
    The formatted string, ending with a literal 'Z'.
  '''
  if not utcnow:
    # datetime.utcnow() is deprecated since Python 3.12. An aware "now" in UTC
    # has the same field values, so the formatted output is unchanged.
    from datetime import timezone
    utcnow = datetime.now(timezone.utc)
  return utcnow.strftime(r'%Y-%m-%dT%H:%M:%SZ')
def defer(coro):
  '''Schedule coro as a task on the current event loop and return the task.

  The task is named after the coroutine's code object name (cr_code.co_name).
  '''
  scheduled = asyncio.get_event_loop().create_task(coro)
  coro_name = coro.cr_code.co_name
  scheduled.set_name(coro_name)
  return scheduled
def init_logging():
  '''Configure logging via basicConfig: DEBUG if VERBOSE is set, else INFO.'''
  verbose = os.getenv('VERBOSE')
  log_level = logging.DEBUG if verbose else logging.INFO
  logging.basicConfig(
      level=log_level,
      format='%(levelname)-8s %(asctime)s %(message)s',
      datefmt=r'%Y-%m-%d %H:%M:%S')