1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// The `Args` type parses the arguments to the `#[throws]` macro.
//
// It is also responsible for transforming the function's return type by
// injecting the original return type and the error type as generic arguments
// of the wrapper type (`::core::result::Result` by default).

use proc_macro2::Span;
use syn::{GenericArgument, Path, PathArguments, ReturnType, Token, Type};
use syn::parse::*;

/// Panic message used whenever the wrapper type cannot be treated as a plain
/// path type (e.g. it has no segments or uses parenthesized arguments).
const WRAPPER_MUST_BE_PATH: &str = "Wrapper type must be a normal path type";

/// Parsed arguments of the `#[throws]` attribute.
///
/// Both fields are `Option`s because `Args::ret` moves them out with
/// `Option::take` when it rewrites the function's return type.
pub struct Args {
    /// The wrapper type whose innermost generic arguments receive the return
    /// (and error) type; `::core::result::Result` unless `as Wrapper` is given.
    wrapper: Option<Type>,
    /// The error type injected after the return type, or `None` when the
    /// attribute starts with `as` (e.g. `#[throws(as Option)]`).
    error: Option<Type>,
}

impl Args {
    pub fn ret(&mut self, ret: ReturnType) -> ReturnType {
        let (arrow, ret) = match ret {
            ReturnType::Default         => (arrow(), unit()),
            ReturnType::Type(arrow, ty) => (arrow, *ty),
        };
        ReturnType::Type(arrow, Box::new(self.inject_to_wrapper(ret)))

    }

    fn inject_to_wrapper(&mut self, ret: Type) -> Type {
        if let Some(Type::Path(mut wrapper)) = self.wrapper.take() {
            let types = if let Some(error) = self.error.take() {
                vec![ret, error].into_iter().map(GenericArgument::Type)
            } else {
                vec![ret].into_iter().map(GenericArgument::Type)
            };

            match innermost_path_arguments(&mut wrapper.path) {
                args @ &mut PathArguments::None    => {
                    *args = PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
                        colon2_token: None,
                        lt_token: Token![<](Span::call_site()),
                        args: types.collect(),
                        gt_token: Token![>](Span::call_site()),
                    });
                }
                PathArguments::AngleBracketed(args) => args.args.extend(types),
                _   => panic!(WRAPPER_MUST_BE_PATH)
            }

            Type::Path(wrapper)
        } else { panic!(WRAPPER_MUST_BE_PATH) }
    }
}

impl Parse for Args {
    fn parse(input: ParseStream) -> Result<Args> {
        if input.is_empty() {
            return Ok(Args {
                error: Some(default_error()),
                wrapper: Some(result()),
            })
        }

        let error = match input.peek(Token![as]) {
            true    => None,
            false   => {
                let error = input.parse()?;
                Some(match error {
                    Type::Infer(_)  => default_error(),
                    _               => error,
                })
            }
        };

        let wrapper = Some(match input.parse::<Token![as]>().is_ok() {
            true    => input.parse()?,
            false   => result(),
        });

        Ok(Args { error, wrapper })
    }
}

/// Finds the generic-argument slot into which the return/error types are injected.
///
/// Starting at the last segment of `path`, this follows the last generic
/// argument downwards for as long as it is itself a path type. It returns the
/// arguments of the innermost segment reached, which are always
/// `PathArguments::None`. For example, for `Poll<Result>` it returns the (empty)
/// arguments of `Result`.
///
/// # Panics
///
/// - Panics with `WRAPPER_MUST_BE_PATH` if `path` has no segments or a segment
///   uses parenthesized arguments (`Fn(A) -> B`).
/// - Panics if an angle-bracketed list is empty or its last argument is not a
///   path type (e.g. a lifetime, as in `MyTryType<'a>`).
fn innermost_path_arguments(path: &mut Path) -> &mut PathArguments {
    let arguments = &mut path.segments.last_mut().expect(WRAPPER_MUST_BE_PATH).arguments;
    match arguments {
        PathArguments::None => arguments,
        PathArguments::AngleBracketed(args) => {
            match args.args.last_mut() {
                Some(GenericArgument::Type(Type::Path(inner))) => {
                    innermost_path_arguments(&mut inner.path)
                }
                // Bizarre cases like `#[throws(_ as MyTryType<'a>)]` are not supported currently.
                _ => panic!("Certain strange wrapper types not supported"),
            }
        }
        // A literal format string is required: `panic!(CONST)` is a hard error
        // in edition 2021 and a `non_fmt_panics` warning before it.
        _ => panic!("{}", WRAPPER_MUST_BE_PATH),
    }
}

fn arrow() -> syn::token::RArrow {
    Token![->](Span::call_site())
}

fn unit() -> Type {
    syn::parse_str("()").unwrap()
}

/// The default wrapper type: the fully qualified `::core::result::Result`,
/// written without generic arguments so they can be injected later.
fn result() -> Type {
    syn::parse_str::<Type>("::core::result::Result").unwrap()
}

fn default_error() -> Type {
    syn::parse_str("Error").unwrap()
}