Files
dify/api/controllers/console/auth/oauth_server.py

212 lines
7.4 KiB
Python
Raw Permalink Normal View History

import hmac
from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar

from flask import jsonify, request
from flask_restx import Resource
from pydantic import BaseModel
from werkzeug.exceptions import BadRequest, NotFound

from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from models import Account
from models.model import OAuthProviderApp
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService

from .. import console_ns
# Captures the remaining positional/keyword parameters of a decorated view.
P = ParamSpec("P")
# Return type of a decorated view.
R = TypeVar("R")
# Type of a decorated view's first positional argument (the `self` of the resource).
T = TypeVar("T")
class OAuthClientPayload(BaseModel):
    """Minimal JSON body shape: only the `client_id` field is declared.

    Used by `oauth_server_client_id_required` to pull the client identifier
    out of the request body before the matching provider app is looked up.
    """

    # Identifier of the OAuth client; required.
    client_id: str
class OAuthProviderRequest(BaseModel):
    """JSON body accepted by the `/oauth/provider` endpoint.

    Both fields are required strings.
    """

    # Identifier of the OAuth client.
    client_id: str
    # Redirect URI supplied by the caller; the endpoint checks it against the
    # provider app's `redirect_uris`.
    redirect_uri: str
class OAuthTokenRequest(BaseModel):
    """JSON body accepted by the `/oauth/provider/token` endpoint.

    `client_id` and `grant_type` are always required. The remaining fields
    default to `None`; which of them must be present depends on `grant_type`
    and is enforced by the endpoint handler, not by this model.
    """

    # Always required.
    client_id: str
    grant_type: str

    # Optional at the schema level; validated per grant type by the handler.
    code: str | None = None
    client_secret: str | None = None
    redirect_uri: str | None = None
    refresh_token: str | None = None
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
    """Decorate a view so it receives the provider app named by the request's `client_id`.

    The wrapper reads the JSON request body, validates it with
    `OAuthClientPayload`, resolves the app through
    `OAuthServerService.get_oauth_provider_app`, and calls `view` with that app
    inserted as the first argument after `self`.

    Raises:
        BadRequest: the parsed JSON body is `None`.
        NotFound: the lookup returns a falsy value for the given `client_id`.
    """

    @wraps(view)
    def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
        body = request.get_json()
        if body is None:
            raise BadRequest("client_id is required")

        requested_client_id = OAuthClientPayload.model_validate(body).client_id
        provider_app = OAuthServerService.get_oauth_provider_app(requested_client_id)
        if provider_app:
            return view(self, provider_app, *args, **kwargs)

        raise NotFound("client_id is invalid")

    return decorated
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
    """Decorate a view so it only runs for requests carrying a valid Bearer token.

    The wrapper expects `oauth_provider_app` as the first argument after `self`.
    It parses the `Authorization` header, validates the token with
    `OAuthServerService.validate_oauth_access_token`, and calls `view` with the
    resulting account inserted after `oauth_provider_app`.

    Every authentication failure is returned (not raised) as a 401 JSON
    response with a `WWW-Authenticate: Bearer` header.

    Raises:
        BadRequest: `oauth_provider_app` is not an `OAuthProviderApp` instance.
    """

    def _unauthorized(message: str):
        """Build the 401 JSON response carrying a Bearer challenge."""
        denied = jsonify({"error": message})
        denied.status_code = 401
        denied.headers["WWW-Authenticate"] = "Bearer"
        return denied

    @wraps(view)
    def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
        if not isinstance(oauth_provider_app, OAuthProviderApp):
            raise BadRequest("Invalid oauth_provider_app")

        raw_header = request.headers.get("Authorization")
        if not raw_header:
            return _unauthorized("Authorization header is required")

        # Expect exactly "<scheme> <credentials>", separated by whitespace.
        pieces = raw_header.strip().split(None, 1)
        if len(pieces) != 2:
            return _unauthorized("Invalid Authorization header format")

        scheme = pieces[0].strip()
        if scheme.lower() != "bearer":
            return _unauthorized("token_type is invalid")

        bearer_token = pieces[1].strip()
        if not bearer_token:
            return _unauthorized("access_token is required")

        account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, bearer_token)
        if not account:
            return _unauthorized("access_token or client_id is invalid")

        return view(self, oauth_provider_app, account, *args, **kwargs)

    return decorated
@console_ns.route("/oauth/provider")
class OAuthServerAppApi(Resource):
    """Expose display information about an OAuth provider app."""

    @setup_required
    @oauth_server_client_id_required
    def post(self, oauth_provider_app: OAuthProviderApp):
        """Return the app's icon, label and scope for a registered redirect URI.

        Raises:
            BadRequest: the submitted `redirect_uri` is not among the app's
                `redirect_uris`.
        """
        submitted = OAuthProviderRequest.model_validate(request.get_json())

        # Only reveal app details to callers presenting a registered redirect URI.
        if submitted.redirect_uri not in oauth_provider_app.redirect_uris:
            raise BadRequest("redirect_uri is invalid")

        exposed_fields = ("app_icon", "app_label", "scope")
        app_info = {field: getattr(oauth_provider_app, field) for field in exposed_fields}
        return jsonable_encoder(app_info)
@console_ns.route("/oauth/provider/authorize")
class OAuthServerUserAuthorizeApi(Resource):
    """Issue an authorization code for the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @oauth_server_client_id_required
    def post(self, oauth_provider_app: OAuthProviderApp):
        """Return `{"code": ...}` for the current account and the given app.

        The code is whatever `OAuthServerService.sign_oauth_authorization_code`
        returns for the app's `client_id` and the current account's id.
        """
        authorizing_account, _ = current_account_with_tenant()
        account_id = authorizing_account.id

        signed_code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, account_id)
        return jsonable_encoder({"code": signed_code})
@console_ns.route("/oauth/provider/token")
class OAuthServerUserTokenApi(Resource):
    """Token endpoint: exchange an authorization code or refresh token for tokens."""

    @setup_required
    @oauth_server_client_id_required
    def post(self, oauth_provider_app: OAuthProviderApp):
        """Issue an access/refresh token pair for the requested grant type.

        For the authorization-code grant the request must carry `code`, a
        `client_secret` equal to the app's secret, and a `redirect_uri` listed
        in the app's `redirect_uris`. For the refresh-token grant only
        `refresh_token` is required.

        Returns:
            A JSON-encodable dict with `access_token`, `token_type`,
            `expires_in` and `refresh_token`.

        Raises:
            BadRequest: unknown or unsupported `grant_type`, or a missing or
                invalid field for the selected grant.
        """
        payload = OAuthTokenRequest.model_validate(request.get_json())

        try:
            grant_type = OAuthGrantType(payload.grant_type)
        except ValueError:
            # The enum's ValueError adds nothing for the client; suppress the chain.
            raise BadRequest("invalid grant_type") from None

        if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
            if not payload.code:
                raise BadRequest("code is required")
            if not self._client_secret_matches(payload.client_secret, oauth_provider_app.client_secret):
                raise BadRequest("client_secret is invalid")
            if payload.redirect_uri not in oauth_provider_app.redirect_uris:
                raise BadRequest("redirect_uri is invalid")

            access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
                grant_type, code=payload.code, client_id=oauth_provider_app.client_id
            )
        elif grant_type == OAuthGrantType.REFRESH_TOKEN:
            # NOTE(review): this grant does not verify client_secret; confirm
            # that is intended for the clients this server supports.
            if not payload.refresh_token:
                raise BadRequest("refresh_token is required")

            access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
                grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
            )
        else:
            # A grant type the enum accepts but this endpoint does not handle
            # must fail loudly instead of falling through and returning None.
            raise BadRequest("invalid grant_type")

        return jsonable_encoder(
            {
                "access_token": access_token,
                "token_type": "Bearer",
                "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
                "refresh_token": refresh_token,
            }
        )

    @staticmethod
    def _client_secret_matches(provided: object, expected: object) -> bool:
        """Compare two client secrets in constant time.

        Returns False when either value is not a string, so a missing secret
        on either side never authenticates. Both values are encoded to bytes
        because `hmac.compare_digest` rejects non-ASCII `str` arguments.
        """
        if not isinstance(provided, str) or not isinstance(expected, str):
            return False
        return hmac.compare_digest(provided.encode("utf-8"), expected.encode("utf-8"))
@console_ns.route("/oauth/provider/account")
class OAuthServerUserAccountApi(Resource):
    """Return profile fields of the account bound to the presented access token."""

    @setup_required
    @oauth_server_client_id_required
    @oauth_server_access_token_required
    def post(self, oauth_provider_app: OAuthProviderApp, account: Account):
        """Return the account's name, email, avatar, interface language and timezone.

        `oauth_provider_app` is supplied by the decorator chain and is not
        used in the body.
        """
        profile_fields = ("name", "email", "avatar", "interface_language", "timezone")
        profile = {field: getattr(account, field) for field in profile_fields}
        return jsonable_encoder(profile)